Skip to main content

agent_sdk/
primitive_tools.rs

1//! Primitive tools that work with the Environment abstraction.
2//!
3//! These tools provide basic file and command operations:
4//! - `ReadTool` - Read file contents
5//! - `WriteTool` - Write/create files
6//! - `EditTool` - Edit existing files with string replacement
7//! - `GlobTool` - Find files by pattern
8//! - `GrepTool` - Search file contents
9//! - `BashTool` - Execute shell commands
10//!
11//! All tools respect `AgentCapabilities` for security.
12
13mod bash;
14mod edit;
15mod glob;
16mod grep;
17mod read;
18mod write;
19
20pub use bash::BashTool;
21pub use edit::EditTool;
22pub use glob::GlobTool;
23pub use grep::GrepTool;
24pub use read::ReadTool;
25pub use write::WriteTool;
26
27use crate::{AgentCapabilities, Environment};
28use serde::Deserialize;
29use serde::de::{self, Deserializer};
30use std::fmt::Display;
31use std::str::FromStr;
32use std::sync::Arc;
33
34/// Context for primitive tools that need environment access
35pub struct PrimitiveToolContext<E: Environment> {
36    pub environment: Arc<E>,
37    pub capabilities: AgentCapabilities,
38}
39
40impl<E: Environment> PrimitiveToolContext<E> {
41    #[must_use]
42    pub const fn new(environment: Arc<E>, capabilities: AgentCapabilities) -> Self {
43        Self {
44            environment,
45            capabilities,
46        }
47    }
48}
49
50impl<E: Environment> Clone for PrimitiveToolContext<E> {
51    fn clone(&self) -> Self {
52        Self {
53            environment: Arc::clone(&self.environment),
54            capabilities: self.capabilities.clone(),
55        }
56    }
57}
58
59#[derive(Deserialize)]
60#[serde(untagged)]
61enum StringOrU64 {
62    Number(u64),
63    String(String),
64}
65
66#[derive(Deserialize)]
67#[serde(untagged)]
68enum StringOrUsize {
69    Number(usize),
70    String(String),
71}
72
/// Parses `value` as a `T` after stripping leading/trailing whitespace.
///
/// # Errors
///
/// Returns a human-readable message when the trimmed text is not a valid `T`.
/// The message quotes the *untrimmed* input so the caller sees exactly what
/// was supplied.
fn parse_numeric_string<T>(value: &str) -> Result<T, String>
where
    T: FromStr,
    T::Err: Display,
{
    let trimmed = value.trim();
    match trimmed.parse::<T>() {
        Ok(parsed) => Ok(parsed),
        Err(reason) => Err(format!("invalid numeric string '{value}': {reason}")),
    }
}
83
84pub(super) fn deserialize_optional_u64_from_string_or_int<'de, D>(
85    deserializer: D,
86) -> Result<Option<u64>, D::Error>
87where
88    D: Deserializer<'de>,
89{
90    match Option::<StringOrU64>::deserialize(deserializer)? {
91        None => Ok(None),
92        Some(StringOrU64::Number(value)) => Ok(Some(value)),
93        Some(StringOrU64::String(value)) => parse_numeric_string(&value)
94            .map(Some)
95            .map_err(de::Error::custom),
96    }
97}
98
/// Returns the longest prefix of `s` that is at most `max_bytes` long and
/// ends on a UTF-8 character boundary.
///
/// If `s` already fits, it is returned unchanged; a multi-byte character that
/// straddles the limit is dropped entirely rather than split.
pub(crate) fn truncate_str(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Search downward from the limit for the nearest boundary. Index 0 is
    // always a char boundary, so the fallback is never actually taken.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
111
112pub(super) fn deserialize_usize_from_string_or_int<'de, D>(
113    deserializer: D,
114) -> Result<usize, D::Error>
115where
116    D: Deserializer<'de>,
117{
118    match StringOrUsize::deserialize(deserializer)? {
119        StringOrUsize::Number(value) => Ok(value),
120        StringOrUsize::String(value) => parse_numeric_string(&value).map_err(de::Error::custom),
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::{parse_numeric_string, truncate_str};

    #[test]
    fn test_truncate_str_ascii_fits() {
        assert_eq!(truncate_str("hello", 10), "hello");
    }

    #[test]
    fn test_truncate_str_ascii_exact() {
        assert_eq!(truncate_str("hello", 5), "hello");
    }

    #[test]
    fn test_truncate_str_ascii_truncated() {
        assert_eq!(truncate_str("hello world", 5), "hello");
    }

    #[test]
    fn test_truncate_str_multibyte_emoji() {
        let s = "Hello 🎉 world";
        // "Hello " is 6 bytes, emoji is 4 bytes, so cutting at 8 would
        // land inside the emoji. The helper must back up to byte 6.
        let result = truncate_str(s, 8);
        assert_eq!(result, "Hello ");
    }

    #[test]
    fn test_truncate_str_cjk() {
        let s = "漢字テスト";
        // Each CJK char is 3 bytes. Truncating at 7 should give 2 chars (6 bytes).
        let result = truncate_str(s, 7);
        assert_eq!(result, "漢字");
    }

    #[test]
    fn test_truncate_str_cut_on_char_boundary() {
        // A limit that lands exactly on the end of a multi-byte char keeps it.
        assert_eq!(truncate_str("漢字テスト", 6), "漢字");
    }

    #[test]
    fn test_truncate_str_cut_inside_first_char() {
        // A limit narrower than the first char leaves nothing to keep.
        assert_eq!(truncate_str("漢字", 2), "");
    }

    #[test]
    fn test_truncate_str_zero_max() {
        assert_eq!(truncate_str("hello", 0), "");
    }

    #[test]
    fn test_truncate_str_empty() {
        assert_eq!(truncate_str("", 10), "");
    }

    #[test]
    fn test_parse_numeric_string_trims_whitespace() {
        assert_eq!(parse_numeric_string::<u64>("  42\n"), Ok(42));
        assert_eq!(parse_numeric_string::<usize>("7"), Ok(7));
    }

    #[test]
    fn test_parse_numeric_string_error_quotes_raw_input() {
        // The message must echo the untrimmed input so callers can see
        // exactly what was supplied.
        let err = parse_numeric_string::<u64>(" abc ").unwrap_err();
        assert!(err.starts_with("invalid numeric string ' abc ':"), "{err}");
    }

    #[test]
    fn test_parse_numeric_string_rejects_empty_and_negative() {
        assert!(parse_numeric_string::<u64>("").is_err());
        assert!(parse_numeric_string::<u64>("   ").is_err());
        assert!(parse_numeric_string::<u64>("-1").is_err());
    }
}