heartbit_core/tool/builtins/
mod.rs1#![allow(missing_docs)]
4mod bash;
5mod edit;
6mod file_tracker;
7mod glob;
8mod grep;
9mod image_generate;
10mod list;
11mod patch;
12mod question;
13mod read;
14mod skill;
15mod todo;
16mod tts;
17pub(crate) mod twitter_post;
18mod webfetch;
19mod websearch;
20mod write;
21
22use std::path::PathBuf;
23use std::sync::Arc;
24
25use crate::tool::Tool;
26
27fn is_protected(path: &std::path::Path, protected: &[PathBuf]) -> bool {
36 let normalized = crate::workspace::normalize_path(path);
37 for pp in protected {
38 if normalized.starts_with(pp) || normalized == *pp {
39 return true;
40 }
41 if let Some(pattern) = pp.to_str()
42 && let Some(pat_ext) = pattern.strip_prefix("*.")
43 && let Some(ext) = normalized.extension().and_then(|e| e.to_str())
44 && ext.eq_ignore_ascii_case(pat_ext)
45 {
46 return true;
47 }
48 }
49 false
50}
51
52pub(crate) async fn write_no_follow(path: &std::path::Path, bytes: &[u8]) -> std::io::Result<()> {
60 #[cfg(unix)]
61 {
62 use std::io::Write;
63 use std::os::unix::fs::OpenOptionsExt;
64 let path_owned = path.to_path_buf();
65 let bytes = bytes.to_vec();
66 tokio::task::spawn_blocking(move || -> std::io::Result<()> {
67 let mut file = std::fs::OpenOptions::new()
68 .write(true)
69 .create(true)
70 .truncate(true)
71 .custom_flags(libc::O_NOFOLLOW)
72 .open(&path_owned)?;
73 file.write_all(&bytes)?;
74 file.sync_all()?;
75 Ok(())
76 })
77 .await
78 .map_err(|e| std::io::Error::other(format!("spawn_blocking failed: {e}")))?
79 }
80 #[cfg(not(unix))]
81 {
82 tokio::fs::write(path, bytes).await
87 }
88}
89
90pub(crate) fn resolve_path(
92 path: &str,
93 workspace: Option<&std::path::Path>,
94 protected_paths: &[PathBuf],
95) -> Result<PathBuf, String> {
96 let p = std::path::Path::new(path);
97
98 match workspace {
99 Some(ws) => {
100 if p.is_absolute() {
101 return Err(format!(
102 "Absolute paths are not allowed when workspace is set. \
103 Use a relative path instead of '{path}'."
104 ));
105 }
106 let candidate = ws.join(p);
107 let normalized = crate::workspace::normalize_path(&candidate);
108 if !normalized.starts_with(ws) {
109 return Err(format!(
110 "Path '{path}' escapes the workspace root ({}).",
111 ws.display()
112 ));
113 }
114 if let Ok(canonical) = normalized.canonicalize()
115 && !canonical.starts_with(ws)
116 {
117 return Err(format!(
118 "Path '{path}' resolves to {} which is outside the workspace.",
119 canonical.display()
120 ));
121 }
122 if is_protected(&normalized, protected_paths) {
123 return Err(format!("Access to '{path}' is denied (protected path)."));
124 }
125 Ok(normalized)
126 }
127 None => {
128 let result = p.to_path_buf();
129 if is_protected(&result, protected_paths) {
130 return Err(format!("Access to '{path}' is denied (protected path)."));
131 }
132 Ok(result)
133 }
134 }
135}
136
/// Returns the largest byte index `<= target` that falls on a UTF-8 character
/// boundary of `text`. `target` is first clamped to `text.len()`.
///
/// Useful for truncating a string to at most `target` bytes without splitting
/// a multi-byte character.
pub fn floor_char_boundary(text: &str, target: usize) -> usize {
    if target >= text.len() {
        // The end of a string is always a boundary.
        return text.len();
    }
    // Index 0 is always a boundary, so the search cannot come up empty.
    (0..=target)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0)
}
144
145pub use file_tracker::FileTracker;
146pub use question::{
147 OnQuestion, Question, QuestionOption, QuestionRequest, QuestionResponse, QuestionTool,
148};
149pub use todo::{TodoPriority, TodoStatus, TodoStore};
150pub use twitter_post::TwitterCredentials;
151
/// Coarse risk classification for a tool.
///
/// NOTE(review): not referenced in this module; the exact policy attached to
/// each level is defined by its users — confirm there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolRisk {
    /// Presumably a tool considered harmless to run without extra gating.
    Safe,
    /// Presumably a tool that needs explicit opt-in (compare the
    /// `dangerous_tools` flag that gates `bash`).
    Dangerous,
}
160
/// Configuration consumed by `builtin_tools` to decide which built-in tools to
/// construct and how to scope them.
#[non_exhaustive]
pub struct BuiltinToolsConfig {
    /// Shared tracker handed to the read, write, edit and patch tools.
    pub file_tracker: Arc<FileTracker>,
    /// Backing store passed to the todo tools.
    pub todo_store: Arc<TodoStore>,
    /// Callback for the question tool; that tool is only registered when `Some`.
    pub on_question: Option<Arc<OnQuestion>>,
    /// Optional workspace root.
    ///
    /// When set, `builtin_tools` canonicalizes it (falling back to the path as
    /// given) and passes it to the file tools. It also builds `bash` with
    /// `BashTool::with_sandbox` rooted here.
    pub workspace: Option<PathBuf>,
    /// Registers the `bash` tool when `true`.
    pub dangerous_tools: bool,
    /// Environment policy for `bash`.
    ///
    /// Only applied when `workspace` is set. NOTE(review): the
    /// `BashTool::new()` path used without a workspace does not receive it.
    pub env_policy: crate::workspace::EnvPolicy,
    /// Paths the file tools refuse to touch.
    ///
    /// `*.ext` entries match by file extension; other entries match as
    /// component-wise path prefixes.
    pub protected_paths: Vec<PathBuf>,
    /// Optional sandbox policy for `bash` (Linux + `sandbox` feature only).
    /// Applied only when `workspace` is set.
    #[cfg(all(target_os = "linux", feature = "sandbox"))]
    pub sandbox_policy: Option<crate::sandbox::SandboxPolicy>,
    /// Registers the `twitter_post` tool when `Some`.
    pub twitter_credentials: Option<TwitterCredentials>,
    /// When `Some`, only tools whose definition name is listed are returned;
    /// `Some(vec![])` yields no tools.
    ///
    /// Filtering happens after construction, so listing a gated tool (e.g.
    /// `bash` with `dangerous_tools == false`) does not enable it.
    pub allowlist: Option<Vec<String>>,
    /// Optional path policy attached via `with_path_policy` to `bash` and the
    /// file tools (read, write, edit, grep, glob, list, patch).
    ///
    /// The web, image, tts, skill, todo, question and twitter tools are not
    /// wrapped.
    pub path_policy: Option<Arc<crate::sandbox::CorePathPolicy>>,
}
189
/// Returns the protected-path list used by `BuiltinToolsConfig::default()`.
///
/// The list contains, in this order:
/// * extension patterns for common secret / key-store files: `*.env`, `*.pem`,
///   `*.key`, `*.p12`, `*.pfx`, `*.kdbx`;
/// * sensitive system files: `/etc/shadow`, `/etc/sudoers`,
///   `/proc/self/environ`;
/// * when `HOME` is set, credential locations under it: `~/.ssh`, `~/.aws`,
///   `~/.gnupg`, `~/.config/heartbit`, `~/.docker/config.json`, `~/.netrc`.
///
/// Home entries are silently omitted when `HOME` is unset. NOTE(review): on
/// Windows the home directory is usually in `USERPROFILE` instead.
pub fn default_protected_paths() -> Vec<PathBuf> {
    const EXTENSION_PATTERNS: [&str; 6] = ["*.env", "*.pem", "*.key", "*.p12", "*.pfx", "*.kdbx"];
    const SYSTEM_FILES: [&str; 3] = ["/etc/shadow", "/etc/sudoers", "/proc/self/environ"];
    // Each entry is a sequence of components joined onto $HOME.
    const HOME_ENTRIES: [&[&str]; 6] = [
        &[".ssh"],
        &[".aws"],
        &[".gnupg"],
        &[".config", "heartbit"],
        &[".docker", "config.json"],
        &[".netrc"],
    ];

    let mut protected: Vec<PathBuf> = EXTENSION_PATTERNS
        .into_iter()
        .chain(SYSTEM_FILES)
        .map(PathBuf::from)
        .collect();

    if let Some(home) = std::env::var_os("HOME").map(PathBuf::from) {
        protected.extend(
            HOME_ENTRIES
                .iter()
                .map(|parts| parts.iter().fold(home.clone(), |dir, part| dir.join(part))),
        );
    }
    protected
}
223
224impl Default for BuiltinToolsConfig {
225 fn default() -> Self {
226 Self {
227 file_tracker: Arc::new(FileTracker::new()),
228 todo_store: Arc::new(TodoStore::new()),
229 on_question: None,
230 workspace: None,
231 dangerous_tools: false,
232 env_policy: crate::workspace::EnvPolicy::Inherit,
233 protected_paths: default_protected_paths(),
235 #[cfg(all(target_os = "linux", feature = "sandbox"))]
236 sandbox_policy: None,
237 twitter_credentials: None,
238 allowlist: None,
239 path_policy: None,
240 }
241 }
242}
243
244pub fn builtin_tools(config: BuiltinToolsConfig) -> Vec<Arc<dyn Tool>> {
246 let ws = config.workspace.map(|w| w.canonicalize().unwrap_or(w));
247 let pp = Arc::new(config.protected_paths);
248 let path_policy = config.path_policy;
249 let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
250
251 macro_rules! maybe_policy {
252 ($tool:expr) => {
253 if let Some(ref pp) = path_policy {
254 $tool.with_path_policy(Arc::clone(pp))
255 } else {
256 $tool
257 }
258 };
259 }
260
261 if config.dangerous_tools {
262 let bash_tool: Arc<dyn Tool> = match &ws {
263 Some(path) => {
264 let tool = bash::BashTool::with_sandbox(path.clone(), config.env_policy);
265 #[cfg(all(target_os = "linux", feature = "sandbox"))]
266 let tool = if let Some(policy) = config.sandbox_policy {
267 tool.with_sandbox_policy(policy)
268 } else {
269 tool
270 };
271 Arc::new(maybe_policy!(tool))
272 }
273 None => Arc::new(maybe_policy!(bash::BashTool::new())),
274 };
275 tools.push(bash_tool);
276 }
277
278 tools.extend([
279 Arc::new(maybe_policy!(read::ReadTool::new(
280 config.file_tracker.clone(),
281 ws.clone(),
282 Arc::clone(&pp),
283 ))) as Arc<dyn Tool>,
284 Arc::new(maybe_policy!(write::WriteTool::new(
285 config.file_tracker.clone(),
286 ws.clone(),
287 Arc::clone(&pp),
288 ))),
289 Arc::new(maybe_policy!(edit::EditTool::new(
290 config.file_tracker.clone(),
291 ws.clone(),
292 Arc::clone(&pp),
293 ))),
294 Arc::new(maybe_policy!(grep::GrepTool::new(
299 ws.clone(),
300 Arc::clone(&pp)
301 ))),
302 Arc::new(maybe_policy!(glob::GlobTool::new(
303 ws.clone(),
304 Arc::clone(&pp)
305 ))),
306 Arc::new(maybe_policy!(list::ListTool::new(
307 ws.clone(),
308 Arc::clone(&pp)
309 ))),
310 Arc::new(maybe_policy!(patch::PatchTool::new(
311 config.file_tracker.clone(),
312 ws,
313 Arc::clone(&pp),
314 ))),
315 Arc::new(webfetch::WebFetchTool::new()),
316 Arc::new(websearch::WebSearchTool::new()),
317 Arc::new(image_generate::ImageGenerateTool::new()),
318 Arc::new(tts::TtsTool::new()),
319 Arc::new(skill::SkillTool::new()),
320 ]);
321
322 let todo_tools = todo::todo_tools(config.todo_store);
323 tools.extend(todo_tools);
324
325 if let Some(on_question) = config.on_question {
326 tools.push(Arc::new(question::QuestionTool::new(on_question)));
327 }
328
329 if let Some(creds) = config.twitter_credentials {
330 tools.push(Arc::new(twitter_post::TwitterPostTool::new(creds)));
331 }
332
333 if let Some(ref allowed) = config.allowlist {
334 let set: std::collections::HashSet<&str> = allowed.iter().map(|s| s.as_str()).collect();
335 tools.retain(|t| set.contains(t.definition().name.as_str()));
336 }
337
338 tools
339}
340
// Unit tests for path resolution, protection rules and tool-set assembly.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn floor_char_boundary_ascii() {
        assert_eq!(floor_char_boundary("hello", 3), 3);
        assert_eq!(floor_char_boundary("hello", 10), 5);
        assert_eq!(floor_char_boundary("hello", 0), 0);
    }

    #[test]
    fn floor_char_boundary_multibyte() {
        let s = "café";
        assert_eq!(s.len(), 5);
        assert_eq!(floor_char_boundary(s, 4), 3);
        assert_eq!(floor_char_boundary(s, 3), 3);
        assert_eq!(floor_char_boundary(s, 5), 5);
    }

    #[test]
    fn is_protected_matches_path_prefix() {
        let protected = vec![PathBuf::from("/home/u/.ssh")];
        assert!(is_protected(std::path::Path::new("/home/u/.ssh/id_rsa"), &protected));
        assert!(is_protected(std::path::Path::new("/home/u/.ssh"), &protected));
        // Prefix matching is per component, not per character.
        assert!(!is_protected(std::path::Path::new("/home/u/.sshx"), &protected));
    }

    #[test]
    fn is_protected_extension_is_case_insensitive() {
        let protected = vec![PathBuf::from("*.pem")];
        assert!(is_protected(std::path::Path::new("/tmp/server.PEM"), &protected));
        assert!(!is_protected(std::path::Path::new("/tmp/server.pem.txt"), &protected));
    }

    #[test]
    fn resolve_path_absolute_rejected_with_workspace() {
        let dir = tempfile::tempdir().unwrap();
        let ws = dir.path();
        let result = resolve_path("/absolute/path", Some(ws), &[]);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .contains("Absolute paths are not allowed")
        );
    }

    #[test]
    fn resolve_path_absolute_passthrough_without_workspace() {
        let result = resolve_path("/absolute/path", None, &[]);
        assert_eq!(result.unwrap(), PathBuf::from("/absolute/path"));
    }

    #[test]
    fn resolve_path_relative_with_workspace() {
        let dir = tempfile::tempdir().unwrap();
        let ws = dir.path().canonicalize().unwrap();
        let result = resolve_path("notes.md", Some(&ws), &[]);
        assert_eq!(result.unwrap(), ws.join("notes.md"));
    }

    #[test]
    fn resolve_path_relative_without_workspace() {
        let result = resolve_path("notes.md", None, &[]);
        assert_eq!(result.unwrap(), PathBuf::from("notes.md"));
    }

    #[test]
    fn resolve_path_new_nested_file_inside_workspace_allowed() {
        let dir = tempfile::tempdir().unwrap();
        let ws = dir.path().canonicalize().unwrap();
        let result = resolve_path("new/sub/file.txt", Some(&ws), &[]);
        assert_eq!(result.unwrap(), ws.join("new/sub/file.txt"));
    }

    #[test]
    fn resolve_path_traversal_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let ws = dir.path().canonicalize().unwrap();
        let result = resolve_path("../../etc/passwd", Some(&ws), &[]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("escapes the workspace"));
    }

    #[test]
    fn resolve_path_internal_dotdot_allowed() {
        let dir = tempfile::tempdir().unwrap();
        let ws = dir.path().canonicalize().unwrap();
        let result = resolve_path("sub/../file.txt", Some(&ws), &[]);
        assert_eq!(result.unwrap(), ws.join("file.txt"));
    }

    #[test]
    fn resolve_path_boundary_dotdot_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let ws = dir.path().canonicalize().unwrap();
        let result = resolve_path("../escape", Some(&ws), &[]);
        assert!(result.is_err());
    }

    // Creating symlinks portably needs platform-specific APIs; gate the whole
    // test instead of early-returning (which left dead code and unused
    // bindings on non-unix targets).
    #[cfg(unix)]
    #[test]
    fn resolve_path_symlink_escape_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let ws = dir.path().canonicalize().unwrap();
        let target = tempfile::tempdir().unwrap();
        std::fs::write(target.path().join("secret.txt"), "secret").unwrap();
        std::os::unix::fs::symlink(target.path(), ws.join("escape_link")).unwrap();
        let result = resolve_path("escape_link/secret.txt", Some(&ws), &[]);
        assert!(
            result.is_err(),
            "symlink escape should be rejected: {result:?}"
        );
    }

    #[test]
    fn resolve_path_rejects_protected_extension() {
        let dir = tempfile::tempdir().unwrap();
        let ws = dir.path().canonicalize().unwrap();
        std::fs::write(ws.join("secret.env"), "SECRET=value").unwrap();
        let protected = vec![PathBuf::from("*.env")];
        let result = resolve_path("secret.env", Some(&ws), &protected);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("protected"));
    }

    #[test]
    fn resolve_path_allows_non_protected() {
        let dir = tempfile::tempdir().unwrap();
        let ws = dir.path().canonicalize().unwrap();
        let protected = vec![PathBuf::from("*.env")];
        let result = resolve_path("notes.md", Some(&ws), &protected);
        assert!(result.is_ok());
    }

    #[test]
    fn builtin_tools_excludes_bash_by_default() {
        let tools = builtin_tools(BuiltinToolsConfig::default());
        assert!(!tools.iter().any(|t| t.definition().name == "bash"));
        assert_eq!(tools.len(), 14);
    }

    #[test]
    fn builtin_tools_includes_bash_when_dangerous() {
        let config = BuiltinToolsConfig {
            dangerous_tools: true,
            ..Default::default()
        };
        let tools = builtin_tools(config);
        assert!(tools.iter().any(|t| t.definition().name == "bash"));
        assert_eq!(tools.len(), 15);
    }

    #[test]
    fn builtin_tools_with_question_callback() {
        let config = BuiltinToolsConfig {
            dangerous_tools: true,
            on_question: Some(Arc::new(|_| {
                Box::pin(async { Ok(QuestionResponse { answers: vec![] }) })
            })),
            ..Default::default()
        };
        let tools = builtin_tools(config);
        assert_eq!(tools.len(), 16);
    }

    #[test]
    fn builtin_tools_includes_twitter_when_credentials_present() {
        let config = BuiltinToolsConfig {
            twitter_credentials: Some(TwitterCredentials {
                consumer_key: "ck".into(),
                consumer_secret: "cs".into(),
                access_token: "at".into(),
                access_token_secret: "ats".into(),
            }),
            ..Default::default()
        };
        let tools = builtin_tools(config);
        assert_eq!(tools.len(), 15);
        assert!(tools.iter().any(|t| t.definition().name == "twitter_post"));
    }

    #[test]
    fn builtin_tools_excludes_twitter_when_no_credentials() {
        let tools = builtin_tools(BuiltinToolsConfig::default());
        assert!(!tools.iter().any(|t| t.definition().name == "twitter_post"));
    }

    #[test]
    fn builtin_tools_with_allowlist() {
        let config = BuiltinToolsConfig {
            allowlist: Some(vec!["websearch".into(), "webfetch".into()]),
            ..Default::default()
        };
        let tools = builtin_tools(config);
        assert_eq!(tools.len(), 2);
        let names: Vec<String> = tools.iter().map(|t| t.definition().name.clone()).collect();
        assert!(names.contains(&"websearch".to_string()));
        assert!(names.contains(&"webfetch".to_string()));
    }

    #[test]
    fn builtin_tools_empty_allowlist() {
        let config = BuiltinToolsConfig {
            allowlist: Some(vec![]),
            ..Default::default()
        };
        let tools = builtin_tools(config);
        assert_eq!(tools.len(), 0);
    }

    #[test]
    fn builtin_tools_allowlist_none_returns_all() {
        let config = BuiltinToolsConfig {
            allowlist: None,
            ..Default::default()
        };
        let tools = builtin_tools(config);
        assert_eq!(tools.len(), 14);
    }

    #[test]
    fn builtin_tools_allowlist_bash_gated() {
        // Listing `bash` must not enable it while dangerous tools are off.
        let config = BuiltinToolsConfig {
            dangerous_tools: false,
            allowlist: Some(vec!["bash".into()]),
            ..Default::default()
        };
        let tools = builtin_tools(config);
        assert_eq!(tools.len(), 0);
    }
}
557}