Skip to main content

lutum_protocol/
toolset.rs

1use std::fmt;
2
3use schemars::{JsonSchema, Schema, schema_for};
4use serde::{Serialize, de::DeserializeOwned};
5use thiserror::Error;
6
7use crate::conversation::{ToolMetadata, ToolUse};
8
9#[derive(Clone, Copy)]
10pub struct ToolDef {
11    pub name: &'static str,
12    pub description: &'static str,
13    schema: fn() -> Schema,
14}
15
16impl fmt::Debug for ToolDef {
17    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18        f.debug_struct("ToolDef")
19            .field("name", &self.name)
20            .field("description", &self.description)
21            .finish_non_exhaustive()
22    }
23}
24
25impl ToolDef {
26    pub const fn new(
27        name: &'static str,
28        description: &'static str,
29        schema: fn() -> Schema,
30    ) -> Self {
31        Self {
32            name,
33            description,
34            schema,
35        }
36    }
37
38    pub fn input_schema(&self) -> Schema {
39        (self.schema)()
40    }
41
42    pub fn for_input<Input>() -> Self
43    where
44        Input: ToolInput,
45    {
46        Self::new(Input::NAME, Input::DESCRIPTION, || schema_for!(Input))
47    }
48}
49
50#[derive(Debug, Error)]
51pub enum ToolCallError {
52    #[error("unknown tool `{name}`")]
53    UnknownTool { name: String },
54    #[error("failed to deserialize tool call for `{name}`: {source}")]
55    Deserialize {
56        name: String,
57        #[source]
58        source: serde_json::Error,
59    },
60}
61
62#[derive(Debug, Error)]
63pub enum ToolUseError {
64    #[error("tool metadata for `{actual}` does not match expected tool `{expected}`")]
65    MismatchedToolName {
66        expected: &'static str,
67        actual: String,
68    },
69    #[error("failed to serialize tool output: {0}")]
70    Serialize(#[from] serde_json::Error),
71}
72
73#[derive(Debug, Error)]
74pub enum ToolExecutionError<E> {
75    #[error("tool execution failed: {0}")]
76    Execute(E),
77    #[error("failed to build tool use: {0}")]
78    ToolUse(#[from] ToolUseError),
79}
80
81pub trait ToolInput:
82    Serialize + DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static
83{
84    type Output: Serialize + DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static;
85
86    const NAME: &'static str;
87    const DESCRIPTION: &'static str;
88
89    fn tool_use(metadata: ToolMetadata, output: Self::Output) -> Result<ToolUse, ToolUseError> {
90        if metadata.name.as_str() != Self::NAME {
91            return Err(ToolUseError::MismatchedToolName {
92                expected: Self::NAME,
93                actual: metadata.name.as_str().to_string(),
94            });
95        }
96        let result = crate::conversation::RawJson::from_serializable(&output)?;
97        Ok(metadata.into_tool_use(result))
98    }
99}
100
101pub trait ToolCallWrapper {
102    fn metadata(&self) -> &ToolMetadata;
103}
104
105impl ToolCallWrapper for std::convert::Infallible {
106    fn metadata(&self) -> &ToolMetadata {
107        match *self {}
108    }
109}
110
111pub trait ToolSelector<T: ?Sized>:
112    Copy
113    + Clone
114    + fmt::Debug
115    + Eq
116    + PartialEq
117    + std::hash::Hash
118    + Serialize
119    + DeserializeOwned
120    + JsonSchema
121    + Send
122    + Sync
123    + 'static
124{
125    fn name(self) -> &'static str;
126
127    fn definition(self) -> &'static ToolDef;
128
129    fn all() -> &'static [Self];
130
131    fn try_from_name(name: &str) -> Option<Self>;
132}
133
134pub trait Toolset: Send + Sync + 'static {
135    type ToolCall: ToolCallWrapper + Clone + fmt::Debug + Eq + PartialEq + Send + Sync + 'static;
136    type Selector: ToolSelector<Self>;
137
138    fn definitions() -> &'static [ToolDef];
139
140    fn definitions_for<I>(selectors: I) -> Vec<&'static ToolDef>
141    where
142        I: IntoIterator<Item = Self::Selector>,
143    {
144        selectors
145            .into_iter()
146            .map(|selector| selector.definition())
147            .collect()
148    }
149
150    fn parse_tool_call(metadata: ToolMetadata) -> Result<Self::ToolCall, ToolCallError>;
151}
152
153#[derive(Clone, Debug, Eq, PartialEq, Default)]
154pub enum ToolPolicy<T: Toolset> {
155    #[default]
156    Disabled,
157    AllowAll,
158    AllowOnly(Vec<T::Selector>),
159    RequireAll,
160    RequireOnly(Vec<T::Selector>),
161}
162
163impl<T> ToolPolicy<T>
164where
165    T: Toolset,
166{
167    pub fn allow_only(selectors: impl IntoIterator<Item = T::Selector>) -> Self {
168        let selectors = selectors.into_iter().collect::<Vec<_>>();
169        if selectors.is_empty() {
170            Self::Disabled
171        } else {
172            Self::AllowOnly(selectors)
173        }
174    }
175
176    pub fn require_only(selectors: impl IntoIterator<Item = T::Selector>) -> Self {
177        let selectors = selectors.into_iter().collect::<Vec<_>>();
178        if selectors.is_empty() {
179            Self::Disabled
180        } else {
181            Self::RequireOnly(selectors)
182        }
183    }
184
185    pub fn uses_tools(&self) -> bool {
186        !matches!(self, Self::Disabled)
187    }
188
189    pub fn requires_tools(&self) -> bool {
190        matches!(self, Self::RequireAll | Self::RequireOnly(_))
191    }
192
193    pub fn selected(&self) -> Option<&[T::Selector]> {
194        match self {
195            Self::AllowOnly(selectors) | Self::RequireOnly(selectors) => Some(selectors.as_slice()),
196            _ => None,
197        }
198    }
199}
200
201#[derive(
202    Clone,
203    Copy,
204    Debug,
205    Eq,
206    PartialEq,
207    Hash,
208    serde::Serialize,
209    serde::Deserialize,
210    schemars::JsonSchema,
211)]
212pub enum NoToolSelector {}
213
214impl ToolSelector<NoTools> for NoToolSelector {
215    fn name(self) -> &'static str {
216        match self {}
217    }
218
219    fn definition(self) -> &'static ToolDef {
220        match self {}
221    }
222
223    fn all() -> &'static [Self] {
224        &[]
225    }
226
227    fn try_from_name(_name: &str) -> Option<Self> {
228        None
229    }
230}
231
/// A toolset with no tools.
///
/// Its selector type ([`NoToolSelector`]) and tool-call type (`Infallible`) have no
/// values, so no tool can be selected or called. As a fieldless marker type it also
/// implements the common comparison and hashing traits.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct NoTools;
235impl Toolset for NoTools {
236    type ToolCall = std::convert::Infallible;
237    type Selector = NoToolSelector;
238
239    fn definitions() -> &'static [ToolDef] {
240        &[]
241    }
242
243    fn parse_tool_call(metadata: ToolMetadata) -> Result<Self::ToolCall, ToolCallError> {
244        Err(ToolCallError::UnknownTool {
245            name: metadata.name.as_str().to_string(),
246        })
247    }
248}