agent_sdk/
primitive_tools.rs1mod bash;
14mod edit;
15mod glob;
16mod grep;
17mod read;
18mod write;
19
20pub use bash::BashTool;
21pub use edit::EditTool;
22pub use glob::GlobTool;
23pub use grep::GrepTool;
24pub use read::ReadTool;
25pub use write::WriteTool;
26
27use crate::{AgentCapabilities, Environment};
28use serde::Deserialize;
29use serde::de::{self, Deserializer};
30use std::fmt::Display;
31use std::str::FromStr;
32use std::sync::Arc;
33
34pub struct PrimitiveToolContext<E: Environment> {
36 pub environment: Arc<E>,
37 pub capabilities: AgentCapabilities,
38}
39
40impl<E: Environment> PrimitiveToolContext<E> {
41 #[must_use]
42 pub const fn new(environment: Arc<E>, capabilities: AgentCapabilities) -> Self {
43 Self {
44 environment,
45 capabilities,
46 }
47 }
48}
49
50impl<E: Environment> Clone for PrimitiveToolContext<E> {
51 fn clone(&self) -> Self {
52 Self {
53 environment: Arc::clone(&self.environment),
54 capabilities: self.capabilities.clone(),
55 }
56 }
57}
58
59#[derive(Deserialize)]
60#[serde(untagged)]
61enum StringOrU64 {
62 Number(u64),
63 String(String),
64}
65
66#[derive(Deserialize)]
67#[serde(untagged)]
68enum StringOrUsize {
69 Number(usize),
70 String(String),
71}
72
/// Parses `value` as a `T` after stripping leading and trailing whitespace.
///
/// # Errors
///
/// Returns a message of the form `invalid numeric string '<value>': <cause>`
/// when parsing fails. The quoted text is the original, untrimmed input.
fn parse_numeric_string<T>(value: &str) -> Result<T, String>
where
    T: FromStr,
    T::Err: Display,
{
    let trimmed = value.trim();
    match trimmed.parse::<T>() {
        Ok(parsed) => Ok(parsed),
        Err(error) => Err(format!("invalid numeric string '{value}': {error}")),
    }
}
83
84pub(super) fn deserialize_optional_u64_from_string_or_int<'de, D>(
85 deserializer: D,
86) -> Result<Option<u64>, D::Error>
87where
88 D: Deserializer<'de>,
89{
90 match Option::<StringOrU64>::deserialize(deserializer)? {
91 None => Ok(None),
92 Some(StringOrU64::Number(value)) => Ok(Some(value)),
93 Some(StringOrU64::String(value)) => parse_numeric_string(&value)
94 .map(Some)
95 .map_err(de::Error::custom),
96 }
97}
98
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
///
/// When `max_bytes` falls inside a multi-byte character, the cut moves back to
/// the start of that character, so the result never splits a code point and
/// slicing cannot panic. If `s` already fits, it is returned unchanged.
pub(crate) fn truncate_str(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Walk backwards from the limit to the nearest boundary. Index 0 is always
    // a boundary, so the search cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
111
112pub(super) fn deserialize_usize_from_string_or_int<'de, D>(
113 deserializer: D,
114) -> Result<usize, D::Error>
115where
116 D: Deserializer<'de>,
117{
118 match StringOrUsize::deserialize(deserializer)? {
119 StringOrUsize::Number(value) => Ok(value),
120 StringOrUsize::String(value) => parse_numeric_string(&value).map_err(de::Error::custom),
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::truncate_str;

    /// A string shorter than the limit is returned untouched.
    #[test]
    fn test_truncate_str_ascii_fits() {
        assert_eq!(truncate_str("hello", 10), "hello");
    }

    /// A string exactly as long as the limit is returned untouched.
    #[test]
    fn test_truncate_str_ascii_exact() {
        assert_eq!(truncate_str("hello", 5), "hello");
    }

    /// ASCII input is cut at exactly `max_bytes`.
    #[test]
    fn test_truncate_str_ascii_truncated() {
        assert_eq!(truncate_str("hello world", 5), "hello");
    }

    #[test]
    fn test_truncate_str_multibyte_emoji() {
        let s = "Hello 🎉 world";
        let result = truncate_str(s, 8);
        // The emoji occupies bytes 6..10, so a limit of 8 lands inside it and
        // the cut must back up to byte 6.
        assert_eq!(result, "Hello ");
    }

    #[test]
    fn test_truncate_str_cjk() {
        let s = "漢字テスト";
        // Guard against the fixture being re-encoded: it must be five
        // characters of three UTF-8 bytes each.
        assert_eq!(s.len(), 15, "fixture must be five 3-byte characters");
        let result = truncate_str(s, 7);
        // Byte 7 is inside the third character; the cut backs up to byte 6.
        assert_eq!(result, "漢字");
    }

    /// A limit that already sits on a character boundary is used as-is.
    #[test]
    fn test_truncate_str_cjk_on_boundary() {
        assert_eq!(truncate_str("漢字テスト", 6), "漢字");
    }

    /// A zero limit always yields the empty string.
    #[test]
    fn test_truncate_str_zero_max() {
        assert_eq!(truncate_str("hello", 0), "");
    }

    /// Empty input stays empty regardless of the limit.
    #[test]
    fn test_truncate_str_empty() {
        assert_eq!(truncate_str("", 10), "");
    }
}