lutum_protocol/
toolset.rs1use std::fmt;
2
3use schemars::{JsonSchema, Schema, schema_for};
4use serde::{Serialize, de::DeserializeOwned};
5use thiserror::Error;
6
7use crate::conversation::{ToolMetadata, ToolUse};
8
9#[derive(Clone, Copy)]
10pub struct ToolDef {
11 pub name: &'static str,
12 pub description: &'static str,
13 schema: fn() -> Schema,
14}
15
16impl fmt::Debug for ToolDef {
17 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18 f.debug_struct("ToolDef")
19 .field("name", &self.name)
20 .field("description", &self.description)
21 .finish_non_exhaustive()
22 }
23}
24
25impl ToolDef {
26 pub const fn new(
27 name: &'static str,
28 description: &'static str,
29 schema: fn() -> Schema,
30 ) -> Self {
31 Self {
32 name,
33 description,
34 schema,
35 }
36 }
37
38 pub fn input_schema(&self) -> Schema {
39 (self.schema)()
40 }
41
42 pub fn for_input<Input>() -> Self
43 where
44 Input: ToolInput,
45 {
46 Self::new(Input::NAME, Input::DESCRIPTION, || schema_for!(Input))
47 }
48}
49
50#[derive(Debug, Error)]
51pub enum ToolCallError {
52 #[error("unknown tool `{name}`")]
53 UnknownTool { name: String },
54 #[error("failed to deserialize tool call for `{name}`: {source}")]
55 Deserialize {
56 name: String,
57 #[source]
58 source: serde_json::Error,
59 },
60}
61
62#[derive(Debug, Error)]
63pub enum ToolUseError {
64 #[error("tool metadata for `{actual}` does not match expected tool `{expected}`")]
65 MismatchedToolName {
66 expected: &'static str,
67 actual: String,
68 },
69 #[error("failed to serialize tool output: {0}")]
70 Serialize(#[from] serde_json::Error),
71}
72
73#[derive(Debug, Error)]
74pub enum ToolExecutionError<E> {
75 #[error("tool execution failed: {0}")]
76 Execute(E),
77 #[error("failed to build tool use: {0}")]
78 ToolUse(#[from] ToolUseError),
79}
80
81pub trait ToolInput:
82 Serialize + DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static
83{
84 type Output: Serialize + DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static;
85
86 const NAME: &'static str;
87 const DESCRIPTION: &'static str;
88
89 fn tool_use(metadata: ToolMetadata, output: Self::Output) -> Result<ToolUse, ToolUseError> {
90 if metadata.name.as_str() != Self::NAME {
91 return Err(ToolUseError::MismatchedToolName {
92 expected: Self::NAME,
93 actual: metadata.name.as_str().to_string(),
94 });
95 }
96 let result = crate::conversation::RawJson::from_serializable(&output)?;
97 Ok(metadata.into_tool_use(result))
98 }
99}
100
101pub trait ToolCallWrapper {
102 fn metadata(&self) -> &ToolMetadata;
103}
104
105impl ToolCallWrapper for std::convert::Infallible {
106 fn metadata(&self) -> &ToolMetadata {
107 match *self {}
108 }
109}
110
111pub trait ToolSelector<T: ?Sized>:
112 Copy
113 + Clone
114 + fmt::Debug
115 + Eq
116 + PartialEq
117 + std::hash::Hash
118 + Serialize
119 + DeserializeOwned
120 + JsonSchema
121 + Send
122 + Sync
123 + 'static
124{
125 fn name(self) -> &'static str;
126
127 fn definition(self) -> &'static ToolDef;
128
129 fn all() -> &'static [Self];
130
131 fn try_from_name(name: &str) -> Option<Self>;
132}
133
134pub trait Toolset: Send + Sync + 'static {
135 type ToolCall: ToolCallWrapper + Clone + fmt::Debug + Eq + PartialEq + Send + Sync + 'static;
136 type Selector: ToolSelector<Self>;
137
138 fn definitions() -> &'static [ToolDef];
139
140 fn definitions_for<I>(selectors: I) -> Vec<&'static ToolDef>
141 where
142 I: IntoIterator<Item = Self::Selector>,
143 {
144 selectors
145 .into_iter()
146 .map(|selector| selector.definition())
147 .collect()
148 }
149
150 fn parse_tool_call(metadata: ToolMetadata) -> Result<Self::ToolCall, ToolCallError>;
151}
152
153#[derive(Clone, Debug, Eq, PartialEq, Default)]
154pub enum ToolPolicy<T: Toolset> {
155 #[default]
156 Disabled,
157 AllowAll,
158 AllowOnly(Vec<T::Selector>),
159 RequireAll,
160 RequireOnly(Vec<T::Selector>),
161}
162
163impl<T> ToolPolicy<T>
164where
165 T: Toolset,
166{
167 pub fn allow_only(selectors: impl IntoIterator<Item = T::Selector>) -> Self {
168 let selectors = selectors.into_iter().collect::<Vec<_>>();
169 if selectors.is_empty() {
170 Self::Disabled
171 } else {
172 Self::AllowOnly(selectors)
173 }
174 }
175
176 pub fn require_only(selectors: impl IntoIterator<Item = T::Selector>) -> Self {
177 let selectors = selectors.into_iter().collect::<Vec<_>>();
178 if selectors.is_empty() {
179 Self::Disabled
180 } else {
181 Self::RequireOnly(selectors)
182 }
183 }
184
185 pub fn uses_tools(&self) -> bool {
186 !matches!(self, Self::Disabled)
187 }
188
189 pub fn requires_tools(&self) -> bool {
190 matches!(self, Self::RequireAll | Self::RequireOnly(_))
191 }
192
193 pub fn selected(&self) -> Option<&[T::Selector]> {
194 match self {
195 Self::AllowOnly(selectors) | Self::RequireOnly(selectors) => Some(selectors.as_slice()),
196 _ => None,
197 }
198 }
199}
200
201#[derive(
202 Clone,
203 Copy,
204 Debug,
205 Eq,
206 PartialEq,
207 Hash,
208 serde::Serialize,
209 serde::Deserialize,
210 schemars::JsonSchema,
211)]
212pub enum NoToolSelector {}
213
214impl ToolSelector<NoTools> for NoToolSelector {
215 fn name(self) -> &'static str {
216 match self {}
217 }
218
219 fn definition(self) -> &'static ToolDef {
220 match self {}
221 }
222
223 fn all() -> &'static [Self] {
224 &[]
225 }
226
227 fn try_from_name(_name: &str) -> Option<Self> {
228 None
229 }
230}
231
/// Toolset with no tools.
///
/// Its selector and tool-call types are uninhabited, and parsing any tool call reports an
/// unknown tool.
///
/// It also derives `PartialEq`, `Eq`, and `Hash`, so generic types that `#[derive]` those
/// traits over a toolset parameter remain fully usable with `NoTools`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct NoTools;
234
235impl Toolset for NoTools {
236 type ToolCall = std::convert::Infallible;
237 type Selector = NoToolSelector;
238
239 fn definitions() -> &'static [ToolDef] {
240 &[]
241 }
242
243 fn parse_tool_call(metadata: ToolMetadata) -> Result<Self::ToolCall, ToolCallError> {
244 Err(ToolCallError::UnknownTool {
245 name: metadata.name.as_str().to_string(),
246 })
247 }
248}