Skip to main content

anda_core/
tool.rs

1//! Tool traits and registries.
2//!
3//! Tools are reusable capabilities that agents can call through the runtime.
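//! This module provides:
//! - [`Tool`] for strongly typed tool implementations.
//! - [`DynTool`] for runtime dispatch through trait objects.
//! - [`ToolSet`] for name-based registration and lookup.
//! - [`ToolProvider`] and [`ToolProviderSet`] for tools discovered at runtime
//!   (for example from remote MCP servers).
//! - [`ToolGroupInfo`] and [`ToolGroup`] for presenting related tools as one
//!   bundle during discovery.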
8//!
9//! Tools define their own JSON function schema through [`FunctionDefinition`]
10//! and receive typed arguments after the runtime validates and deserializes a
11//! raw JSON call.
12
13use serde::{Deserialize, Serialize, de::DeserializeOwned};
14use std::{any::Any, collections::BTreeMap, future::Future, marker::PhantomData, sync::Arc};
15
16use crate::{
17    BoxError, BoxFut, BoxPinFut, Function, Json, Resource, ToolInput, ToolOutput,
18    context::BaseContext, model::FunctionDefinition, select_resources, validate_function_name,
19};
20
21/// Strongly typed interface for an agent tool.
22///
23/// # Type Parameters
24/// - `C`: Runtime context implementing [`BaseContext`].
25pub trait Tool<C>: Send + Sync
26where
27    C: BaseContext + Send + Sync,
28{
29    /// The arguments type of the tool.
30    type Args: DeserializeOwned + Send;
31
32    /// The output type of the tool.
33    type Output: Serialize;
34
35    /// Returns the unique tool name.
36    ///
37    /// # Rules
38    /// - Must not be empty;
39    /// - Must not exceed 64 characters;
40    /// - Must start with a lowercase letter;
41    /// - Can only contain: lowercase letters (a-z), digits (0-9), and underscores (_);
42    /// - Unique within the engine.
43    fn name(&self) -> String;
44
45    /// Returns a concise description of the tool's capability.
46    fn description(&self) -> String;
47
48    /// Returns the function definition, including the JSON parameter schema.
49    ///
50    /// # Returns
51    /// - `FunctionDefinition`: The schema definition of the tool's parameters and metadata.
52    fn definition(&self) -> FunctionDefinition;
53
54    /// Returns the capability group this tool belongs to, if any.
55    ///
56    /// Tools that form a coherent bundle (for example the filesystem workspace
57    /// tools) return the same [`ToolGroupInfo`] so the registry can present them
58    /// as one group in discovery. The default implementation returns `None`.
59    fn group(&self) -> Option<ToolGroupInfo> {
60        None
61    }
62
63    /// Returns resource tags this tool can consume.
64    ///
65    /// The default implementation returns an empty list, meaning no resources
66    /// are selected for this tool. Return `vec!["*".into()]` to accept all
67    /// attached resources.
68    ///
69    /// # Returns
70    /// Resource tags supported by this tool.
71    fn supported_resource_tags(&self) -> Vec<String> {
72        Vec::new()
73    }
74
75    /// Removes and returns resources matching this tool's supported tags.
76    fn select_resources(&self, resources: &mut Vec<Resource>) -> Vec<Resource> {
77        let supported_tags = self.supported_resource_tags();
78        select_resources(resources, &supported_tags)
79    }
80
81    /// Initializes the tool with the given context.
82    ///
83    /// Runtimes call this once while building the engine.
84    fn init(&self, _ctx: C) -> impl Future<Output = Result<(), BoxError>> + Send {
85        futures::future::ready(Ok(()))
86    }
87
88    /// Executes the tool with typed arguments and selected resources.
89    ///
90    /// # Arguments
91    /// - `ctx`: The execution context implementing [`BaseContext`].
92    /// - `args`: struct arguments for the tool.
93    /// - `resources`: Additional resources selected for this tool.
94    ///
95    /// # Returns
96    /// A future resolving to [`ToolOutput<Self::Output>`].
97    fn call(
98        &self,
99        ctx: C,
100        args: Self::Args,
101        resources: Vec<Resource>,
102    ) -> impl Future<Output = Result<ToolOutput<Self::Output>, BoxError>> + Send;
103
104    /// Executes the tool from raw JSON arguments and returns JSON output.
105    fn call_raw(
106        &self,
107        ctx: C,
108        args: Json,
109        resources: Vec<Resource>,
110    ) -> impl Future<Output = Result<ToolOutput<Json>, BoxError>> + Send {
111        async move {
112            let args: Self::Args = serde_json::from_value(args)
113                .map_err(|err| format!("tool {}, invalid args: {}", self.name(), err))?;
114            let mut result = self
115                .call(ctx, args, resources)
116                .await
117                .map_err(|err| format!("tool {}, call failed: {}", self.name(), err))?;
118            let output = serde_json::to_value(&result.output)?;
119            if result.usage.requests == 0 {
120                result.usage.requests = 1;
121            }
122
123            Ok(ToolOutput {
124                output,
125                is_error: result.is_error,
126                artifacts: result.artifacts,
127                usage: result.usage,
128                tools_usage: result.tools_usage,
129            })
130        }
131    }
132}
133
134/// Object-safe wrapper around [`Tool`] for runtime dispatch.
135///
136/// Runtime registries store tools through this trait so callers can select and
137/// execute tools by name without knowing their concrete Rust types.
138pub trait DynTool<C>: Send + Sync
139where
140    C: BaseContext + Send + Sync,
141{
142    /// Returns this tool as [`Any`] for type inspection.
143    fn as_any(&self) -> &(dyn Any + Send + Sync);
144
145    /// Converts the shared tool into [`Any`] for downcasting.
146    fn into_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
147
148    /// Returns the unique tool name.
149    fn name(&self) -> String;
150
151    /// Returns the function definition exposed to model providers.
152    fn definition(&self) -> FunctionDefinition;
153
154    /// Returns the capability group this tool belongs to, if any.
155    fn group(&self) -> Option<ToolGroupInfo> {
156        None
157    }
158
159    /// Returns resource tags this tool can consume.
160    fn supported_resource_tags(&self) -> Vec<String>;
161
162    /// Initializes the tool through object-safe dispatch.
163    fn init(&self, ctx: C) -> BoxPinFut<Result<(), BoxError>>;
164
165    /// Executes the tool through object-safe dispatch with raw JSON arguments.
166    fn call(
167        &self,
168        ctx: C,
169        args: Json,
170        resources: Vec<Resource>,
171    ) -> BoxPinFut<Result<ToolOutput<Json>, BoxError>>;
172}
173
174/// Group membership a single [`Tool`] declares for itself.
175///
176/// A static tool uses this to say "I belong to bundle X" without knowing the
177/// other members. The registry ([`ToolSet`]) collects every tool that declares
178/// the same `id` and assembles the full [`ToolGroup`], so the member list always
179/// reflects the tools actually registered (no stale or missing entries).
180///
181/// Share one constructor across a bundle's tools to keep the metadata identical;
182/// when ids collide, the first-registered tool's metadata wins.
183#[derive(Debug, Clone, Default, Serialize, Deserialize)]
184pub struct ToolGroupInfo {
185    /// Stable group id, unique across the engine (for example `fs_workspace`).
186    pub id: String,
187    /// Human-facing group title.
188    pub title: String,
189    /// Concise summary of what this bundle of tools does.
190    pub description: String,
191    /// Optional usage instructions describing how the member tools work
192    /// together. Reference for the model, never a runtime directive.
193    #[serde(default, skip_serializing_if = "Option::is_none")]
194    pub instructions: Option<String>,
195}
196
197/// A related set of callables surfaced together from one source.
198///
199/// A group tells the model that a bundle of tools share an origin (for example
200/// a single MCP server, or the built-in filesystem tools) and are meant to be
201/// combined to complete related work. Groups are a *discovery-layer* concept
202/// only: they are never sent to model providers as part of the function-calling
203/// schema. They are returned by the built-in discovery helpers (`tools_search` /
204/// `tools_select`) so the model can understand a bundle's purpose and pull in
205/// sibling tools as needed.
206///
207/// `instructions`, `title`, and `description` may originate from untrusted
208/// remote metadata. They are surfaced as plain data the model reads, never as
209/// system instructions, so they cannot escalate into runtime directives.
210#[derive(Debug, Clone, Default, Serialize, Deserialize)]
211pub struct ToolGroup {
212    /// Stable group id, unique across providers (for example `mcp:filesystem`).
213    pub id: String,
214    /// Human-facing group title.
215    pub title: String,
216    /// Concise summary of what this bundle of tools does.
217    pub description: String,
218    /// Optional usage instructions describing how the member tools work
219    /// together. Untrusted remote metadata; treat as reference, not directives.
220    #[serde(default, skip_serializing_if = "Option::is_none")]
221    pub instructions: Option<String>,
222    /// Model-facing names of the tools that belong to this group.
223    pub members: Vec<String>,
224}
225
226impl ToolGroup {
227    /// Builds a group from a per-tool [`ToolGroupInfo`] and resolved members.
228    pub fn from_info(info: ToolGroupInfo, members: Vec<String>) -> Self {
229        Self {
230            id: info.id,
231            title: info.title,
232            description: info.description,
233            instructions: info.instructions,
234            members,
235        }
236    }
237}
238
239/// Dynamic source of callable tools.
240///
241/// Providers are useful for integrations whose tool set is discovered at
242/// runtime, such as remote MCP servers. A provider exposes a synchronous
243/// snapshot for model-facing discovery and async methods for refresh and call
244/// execution.
245pub trait ToolProvider<C>: Send + Sync
246where
247    C: BaseContext + Send + Sync,
248{
249    /// Returns the provider registry name.
250    ///
251    /// This name is for engine configuration and diagnostics, not a
252    /// model-facing tool name.
253    fn name(&self) -> String;
254
255    /// Returns the current function definitions from this provider.
256    fn definitions(&self, names: Option<&[String]>) -> Vec<FunctionDefinition>;
257
258    /// Returns the capability groups exposed by this provider.
259    ///
260    /// Each group bundles related tools (for example all tools from one MCP
261    /// server) so the discovery layer can tell the model the tools are related
262    /// and how to combine them. The default implementation returns no groups.
263    fn groups(&self) -> Vec<ToolGroup> {
264        Vec::new()
265    }
266
267    /// Returns whether this provider can currently dispatch the lowercase name.
268    fn contains_lowercase(&self, lowercase_name: &str) -> bool {
269        self.definitions(Some(&[lowercase_name.to_string()]))
270            .iter()
271            .any(|definition| definition.name.eq_ignore_ascii_case(lowercase_name))
272    }
273
274    /// Returns resource tags this provider's named tool can consume.
275    fn supported_resource_tags(&self, _name: &str) -> Vec<String> {
276        Vec::new()
277    }
278
279    /// Removes and returns resources matching the named tool.
280    fn select_resources(&self, name: &str, resources: &mut Vec<Resource>) -> Vec<Resource> {
281        let supported_tags = self.supported_resource_tags(name);
282        select_resources(resources, &supported_tags)
283    }
284
285    /// Initializes the provider and refreshes any runtime discovery cache.
286    fn init(&self, _ctx: C) -> BoxFut<'_, Result<(), BoxError>> {
287        Box::pin(async { Ok(()) })
288    }
289
290    /// Refreshes the provider's discovery cache.
291    fn refresh(&self) -> BoxFut<'_, Result<(), BoxError>> {
292        Box::pin(async { Ok(()) })
293    }
294
295    /// Executes a provider-backed tool by model-facing name.
296    fn call(
297        &self,
298        ctx: C,
299        input: ToolInput<Json>,
300    ) -> BoxFut<'_, Result<ToolOutput<Json>, BoxError>>;
301}
302
303impl<C> dyn DynTool<C>
304where
305    C: BaseContext + Send + Sync + 'static,
306{
307    /// Returns the inner concrete tool type when it matches `T`.
308    pub fn downcast_ref<T>(&self) -> Option<&T>
309    where
310        T: Tool<C> + 'static,
311    {
312        self.as_any().downcast_ref::<T>()
313    }
314
315    /// Returns the inner concrete tool when it matches `T`.
316    pub fn downcast<T>(self: Arc<Self>) -> Result<Arc<T>, Arc<Self>>
317    where
318        T: Tool<C> + 'static,
319    {
320        match self.clone().into_any().downcast::<T>() {
321            Ok(tool) => Ok(tool),
322            Err(_) => Err(self),
323        }
324    }
325}
326
327/// Adapter that exposes a concrete [`Tool`] through [`DynTool`].
328struct ToolWrapper<T, C>(Arc<T>, PhantomData<C>)
329where
330    T: Tool<C> + 'static,
331    C: BaseContext + Send + Sync + 'static;
332
333impl<T, C> DynTool<C> for ToolWrapper<T, C>
334where
335    T: Tool<C> + 'static,
336    C: BaseContext + Send + Sync + 'static,
337{
338    fn as_any(&self) -> &(dyn Any + Send + Sync) {
339        self.0.as_ref()
340    }
341
342    fn into_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
343        self.0.clone()
344    }
345
346    fn name(&self) -> String {
347        self.0.name()
348    }
349
350    fn definition(&self) -> FunctionDefinition {
351        self.0.definition()
352    }
353
354    fn group(&self) -> Option<ToolGroupInfo> {
355        self.0.group()
356    }
357
358    fn supported_resource_tags(&self) -> Vec<String> {
359        self.0.supported_resource_tags()
360    }
361
362    fn init(&self, ctx: C) -> BoxPinFut<Result<(), BoxError>> {
363        let tool = self.0.clone();
364        Box::pin(async move { tool.init(ctx).await })
365    }
366
367    fn call(
368        &self,
369        ctx: C,
370        args: Json,
371        resources: Vec<Resource>,
372    ) -> BoxPinFut<Result<ToolOutput<Json>, BoxError>> {
373        let tool = self.0.clone();
374        Box::pin(async move { tool.call_raw(ctx, args, resources).await })
375    }
376}
377
378/// Name-based registry for tools.
379///
380/// # Type Parameters
381/// - `C`: The context type that implements [`BaseContext`].
382#[derive(Default)]
383pub struct ToolSet<C: BaseContext> {
384    /// Registered tools keyed by their lowercase function names.
385    pub set: BTreeMap<String, Arc<dyn DynTool<C>>>,
386}
387
388/// Registry for runtime-discovered tool providers.
389#[derive(Default)]
390pub struct ToolProviderSet<C: BaseContext> {
391    /// Registered providers keyed by provider name.
392    pub set: BTreeMap<String, Arc<dyn ToolProvider<C>>>,
393}
394
395impl<C> ToolProviderSet<C>
396where
397    C: BaseContext + Clone + Send + Sync + 'static,
398{
399    /// Creates an empty provider set.
400    pub fn new() -> Self {
401        Self {
402            set: BTreeMap::new(),
403        }
404    }
405
406    /// Returns whether a provider with the given name exists.
407    pub fn contains_provider(&self, name: &str) -> bool {
408        self.set.contains_key(&name.to_ascii_lowercase())
409    }
410
411    /// Registers a new dynamic tool provider.
412    pub fn add<T>(&mut self, provider: Arc<T>) -> Result<(), BoxError>
413    where
414        T: ToolProvider<C> + Send + Sync + 'static,
415    {
416        let name = provider.name().to_ascii_lowercase();
417        validate_function_name(&name)?;
418        if self.set.contains_key(&name) {
419            return Err(format!("tool provider {} already exists", name).into());
420        }
421
422        self.set.insert(name, provider);
423        Ok(())
424    }
425
426    /// Returns whether any provider can currently dispatch the given name.
427    pub fn contains_lowercase(&self, lowercase_name: &str) -> bool {
428        self.set
429            .values()
430            .any(|provider| provider.contains_lowercase(lowercase_name))
431    }
432
433    /// Returns dynamic function definitions for all providers or selected names.
434    pub fn definitions(&self, names: Option<&[String]>) -> Vec<FunctionDefinition> {
435        match names {
436            Some([]) => Vec::new(),
437            _ => {
438                let mut definitions = BTreeMap::new();
439                for provider in self.set.values() {
440                    for definition in provider.definitions(names) {
441                        definitions
442                            .entry(definition.name.to_ascii_lowercase())
443                            .or_insert(definition);
444                    }
445                }
446                definitions.into_values().collect()
447            }
448        }
449    }
450
451    /// Returns the capability groups exposed by every registered provider.
452    pub fn groups(&self) -> Vec<ToolGroup> {
453        self.set
454            .values()
455            .flat_map(|provider| provider.groups())
456            .collect()
457    }
458
459    /// Returns function metadata for all provider-backed tools or selected names.
460    pub fn functions(&self, names: Option<&[String]>) -> Vec<Function> {
461        self.definitions(names)
462            .into_iter()
463            .map(|definition| {
464                let supported_resource_tags = self
465                    .set
466                    .values()
467                    .find(|provider| provider.contains_lowercase(&definition.name))
468                    .map(|provider| provider.supported_resource_tags(&definition.name))
469                    .unwrap_or_default();
470                Function {
471                    definition,
472                    supported_resource_tags,
473                }
474            })
475            .collect()
476    }
477
478    /// Removes and returns resources supported by the named provider tool.
479    pub fn select_resources(&self, name: &str, resources: &mut Vec<Resource>) -> Vec<Resource> {
480        let lowercase_name = name.to_ascii_lowercase();
481        self.set
482            .values()
483            .find(|provider| provider.contains_lowercase(&lowercase_name))
484            .map(|provider| provider.select_resources(&lowercase_name, resources))
485            .unwrap_or_default()
486    }
487
488    /// Initializes all providers.
489    pub async fn init_all(&self, ctx: C) -> Result<(), BoxError> {
490        for provider in self.set.values() {
491            provider.init(ctx.clone()).await?;
492        }
493        Ok(())
494    }
495
496    /// Refreshes all providers.
497    pub async fn refresh_all(&self) -> Result<(), BoxError> {
498        for provider in self.set.values() {
499            provider.refresh().await?;
500        }
501        Ok(())
502    }
503
504    /// Executes a dynamic provider-backed tool.
505    pub async fn call(
506        &self,
507        ctx: C,
508        mut input: ToolInput<Json>,
509    ) -> Result<ToolOutput<Json>, BoxError> {
510        input.name.make_ascii_lowercase();
511        let provider = self
512            .set
513            .values()
514            .find(|provider| provider.contains_lowercase(&input.name))
515            .ok_or_else(|| format!("tool {} not found", input.name))?;
516        provider.call(ctx, input).await
517    }
518}
519
520impl<C> ToolSet<C>
521where
522    C: BaseContext + Send + Sync + 'static,
523{
524    /// Creates an empty tool set.
525    pub fn new() -> Self {
526        Self {
527            set: BTreeMap::new(),
528        }
529    }
530
531    /// Returns whether a tool with the given name exists.
532    pub fn contains(&self, name: &str) -> bool {
533        self.set.contains_key(&name.to_ascii_lowercase())
534    }
535
536    /// Returns whether a tool with the given lowercase name exists.
537    pub fn contains_lowercase(&self, lowercase_name: &str) -> bool {
538        self.set.contains_key(lowercase_name)
539    }
540
541    /// Returns the names of all registered tools.
542    pub fn names(&self) -> Vec<String> {
543        self.set.keys().cloned().collect()
544    }
545
546    /// Returns the capability groups assembled from registered tools.
547    ///
548    /// Tools that declare the same [`ToolGroupInfo::id`] are collected into one
549    /// [`ToolGroup`] whose `members` are exactly the registered tool names in
550    /// that group, sorted for determinism. Group metadata is taken from the
551    /// first tool (by lowercase name order) that declares the id.
552    pub fn groups(&self) -> Vec<ToolGroup> {
553        let mut grouped: BTreeMap<String, (ToolGroupInfo, Vec<String>)> = BTreeMap::new();
554        for (name, tool) in &self.set {
555            if let Some(info) = tool.group() {
556                grouped
557                    .entry(info.id.clone())
558                    .or_insert_with(|| (info, Vec::new()))
559                    .1
560                    .push(name.clone());
561            }
562        }
563
564        grouped
565            .into_values()
566            .map(|(info, mut members)| {
567                members.sort();
568                ToolGroup::from_info(info, members)
569            })
570            .collect()
571    }
572
573    /// Returns the function definition for a specific tool.
574    pub fn definition(&self, name: &str) -> Option<FunctionDefinition> {
575        self.set
576            .get(&name.to_ascii_lowercase())
577            .map(|tool| tool.definition())
578    }
579
580    /// Returns function definitions for all tools or the selected names.
581    ///
582    /// # Arguments
583    /// - `names`: Optional slice of tool names to filter by.
584    ///
585    /// # Returns
586    /// A vector of tool definitions.
587    pub fn definitions(&self, names: Option<&[String]>) -> Vec<FunctionDefinition> {
588        match names {
589            None => self.set.values().map(|tool| tool.definition()).collect(),
590            Some(names) => names
591                .iter()
592                .filter_map(|name| {
593                    self.set
594                        .get(&name.to_ascii_lowercase())
595                        .map(|tool| tool.definition())
596                })
597                .collect(),
598        }
599    }
600
601    /// Returns function metadata for all tools or the selected names.
602    ///
603    /// # Arguments
604    /// - `names`: Optional slice of tool names to filter by.
605    ///
606    /// # Returns
607    /// A vector of tool function metadata.
608    pub fn functions(&self, names: Option<&[String]>) -> Vec<Function> {
609        match names {
610            None => self
611                .set
612                .values()
613                .map(|tool| Function {
614                    definition: tool.definition(),
615                    supported_resource_tags: tool.supported_resource_tags(),
616                })
617                .collect(),
618            Some(names) => names
619                .iter()
620                .filter_map(|name| {
621                    self.set
622                        .get(&name.to_ascii_lowercase())
623                        .map(|tool| Function {
624                            definition: tool.definition(),
625                            supported_resource_tags: tool.supported_resource_tags(),
626                        })
627                })
628                .collect(),
629        }
630    }
631
632    /// Removes and returns resources supported by the named tool.
633    pub fn select_resources(&self, name: &str, resources: &mut Vec<Resource>) -> Vec<Resource> {
634        self.set
635            .get(&name.to_ascii_lowercase())
636            .map(|tool| {
637                let supported_tags = tool.supported_resource_tags();
638                select_resources(resources, &supported_tags)
639            })
640            .unwrap_or_default()
641    }
642
643    /// Registers a new tool.
644    ///
645    /// # Arguments
646    /// - `tool`: The tool to register.
647    pub fn add<T>(&mut self, tool: Arc<T>) -> Result<(), BoxError>
648    where
649        T: Tool<C> + Send + Sync + 'static,
650    {
651        let name = tool.name().to_ascii_lowercase();
652        validate_function_name(&name)?;
653        if self.set.contains_key(&name) {
654            return Err(format!("tool {} already exists", name).into());
655        }
656
657        let tool_dyn = ToolWrapper(tool, PhantomData);
658        self.set.insert(name, Arc::new(tool_dyn));
659        Ok(())
660    }
661
662    /// Returns a tool by name.
663    pub fn get(&self, name: &str) -> Option<Arc<dyn DynTool<C>>> {
664        self.set.get(&name.to_ascii_lowercase()).cloned()
665    }
666
667    /// Returns a tool by lowercase name.
668    pub fn get_lowercase(&self, lowercase_name: &str) -> Option<Arc<dyn DynTool<C>>> {
669        self.set.get(lowercase_name).cloned()
670    }
671}
672
673#[cfg(test)]
674mod tests {
675    use super::*;
676    use candid::{CandidType, Principal, utils::ArgumentEncoder};
677    use serde_json::json;
678    use std::{sync::Arc, time::Duration};
679
680    use crate::{
681        BaseContext, CacheExpiry, CacheFeatures, CancellationToken, CanisterCaller, HttpFeatures,
682        KeysFeatures, ObjectMeta, Path, PutMode, PutResult, RequestMeta, StateFeatures,
683        StoreFeatures, ToolInput,
684    };
685
686    #[derive(Clone)]
687    struct TestContext {
688        engine_id: Principal,
689        caller: Principal,
690        meta: RequestMeta,
691        cancellation_token: CancellationToken,
692    }
693
694    impl Default for TestContext {
695        fn default() -> Self {
696            Self {
697                engine_id: Principal::management_canister(),
698                caller: Principal::anonymous(),
699                meta: RequestMeta::default(),
700                cancellation_token: CancellationToken::new(),
701            }
702        }
703    }
704
705    impl StateFeatures for TestContext {
706        fn engine_id(&self) -> &Principal {
707            &self.engine_id
708        }
709
710        fn engine_name(&self) -> &str {
711            "test-engine"
712        }
713
714        fn caller(&self) -> &Principal {
715            &self.caller
716        }
717
718        fn meta(&self) -> &RequestMeta {
719            &self.meta
720        }
721
722        fn cancellation_token(&self) -> CancellationToken {
723            self.cancellation_token.clone()
724        }
725
726        fn time_elapsed(&self) -> Duration {
727            Duration::ZERO
728        }
729    }
730
731    impl KeysFeatures for TestContext {
732        async fn a256gcm_key(&self, _derivation_path: Vec<Vec<u8>>) -> Result<[u8; 32], BoxError> {
733            Ok([0; 32])
734        }
735
736        async fn ed25519_sign_message(
737            &self,
738            _derivation_path: Vec<Vec<u8>>,
739            _message: &[u8],
740        ) -> Result<[u8; 64], BoxError> {
741            Ok([0; 64])
742        }
743
744        async fn ed25519_verify(
745            &self,
746            _derivation_path: Vec<Vec<u8>>,
747            _message: &[u8],
748            _signature: &[u8],
749        ) -> Result<(), BoxError> {
750            Ok(())
751        }
752
753        async fn ed25519_public_key(
754            &self,
755            _derivation_path: Vec<Vec<u8>>,
756        ) -> Result<[u8; 32], BoxError> {
757            Ok([0; 32])
758        }
759
760        async fn secp256k1_sign_message_bip340(
761            &self,
762            _derivation_path: Vec<Vec<u8>>,
763            _message: &[u8],
764        ) -> Result<[u8; 64], BoxError> {
765            Ok([0; 64])
766        }
767
768        async fn secp256k1_verify_bip340(
769            &self,
770            _derivation_path: Vec<Vec<u8>>,
771            _message: &[u8],
772            _signature: &[u8],
773        ) -> Result<(), BoxError> {
774            Ok(())
775        }
776
777        async fn secp256k1_sign_message_ecdsa(
778            &self,
779            _derivation_path: Vec<Vec<u8>>,
780            _message: &[u8],
781        ) -> Result<[u8; 64], BoxError> {
782            Ok([0; 64])
783        }
784
785        async fn secp256k1_sign_digest_ecdsa(
786            &self,
787            _derivation_path: Vec<Vec<u8>>,
788            _message_hash: &[u8],
789        ) -> Result<[u8; 64], BoxError> {
790            Ok([0; 64])
791        }
792
793        async fn secp256k1_verify_ecdsa(
794            &self,
795            _derivation_path: Vec<Vec<u8>>,
796            _message_hash: &[u8],
797            _signature: &[u8],
798        ) -> Result<(), BoxError> {
799            Ok(())
800        }
801
802        async fn secp256k1_public_key(
803            &self,
804            _derivation_path: Vec<Vec<u8>>,
805        ) -> Result<[u8; 33], BoxError> {
806            Ok([0; 33])
807        }
808    }
809
810    impl StoreFeatures for TestContext {
811        async fn store_get(&self, _path: &Path) -> Result<(bytes::Bytes, ObjectMeta), BoxError> {
812            Err("not implemented".into())
813        }
814
815        async fn store_list(
816            &self,
817            _prefix: Option<&Path>,
818            _offset: &Path,
819        ) -> Result<Vec<ObjectMeta>, BoxError> {
820            Ok(Vec::new())
821        }
822
823        async fn store_put(
824            &self,
825            _path: &Path,
826            _mode: PutMode,
827            _value: bytes::Bytes,
828        ) -> Result<PutResult, BoxError> {
829            Err("not implemented".into())
830        }
831
832        async fn store_rename_if_not_exists(
833            &self,
834            _from: &Path,
835            _to: &Path,
836        ) -> Result<(), BoxError> {
837            Err("not implemented".into())
838        }
839
840        async fn store_delete(&self, _path: &Path) -> Result<(), BoxError> {
841            Ok(())
842        }
843    }
844
845    impl CacheFeatures for TestContext {
846        fn cache_contains(&self, _key: &str) -> bool {
847            false
848        }
849
850        async fn cache_get<T>(&self, _key: &str) -> Result<T, BoxError>
851        where
852            T: DeserializeOwned,
853        {
854            Err("not implemented".into())
855        }
856
857        async fn cache_get_with<T, F>(&self, _key: &str, _init: F) -> Result<T, BoxError>
858        where
859            T: Sized + DeserializeOwned + Serialize + Send,
860            F: Future<Output = Result<(T, Option<CacheExpiry>), BoxError>> + Send + 'static,
861        {
862            Err("not implemented".into())
863        }
864
865        async fn cache_set<T>(&self, _key: &str, _val: (T, Option<CacheExpiry>))
866        where
867            T: Sized + Serialize + Send,
868        {
869        }
870
871        async fn cache_set_if_not_exists<T>(
872            &self,
873            _key: &str,
874            _val: (T, Option<CacheExpiry>),
875        ) -> bool
876        where
877            T: Sized + Serialize + Send,
878        {
879            false
880        }
881
882        async fn cache_delete(&self, _key: &str) -> bool {
883            false
884        }
885
886        fn cache_raw_iter(
887            &self,
888        ) -> impl Iterator<Item = (Arc<String>, Arc<(bytes::Bytes, Option<CacheExpiry>)>)> {
889            std::iter::empty()
890        }
891    }
892
893    impl HttpFeatures for TestContext {
894        async fn https_call(
895            &self,
896            _url: &str,
897            _method: http::Method,
898            _headers: Option<http::HeaderMap>,
899            _body: Option<Vec<u8>>,
900        ) -> Result<reqwest::Response, BoxError> {
901            Err("not implemented".into())
902        }
903
904        async fn https_signed_call(
905            &self,
906            _url: &str,
907            _method: http::Method,
908            _message_digest: [u8; 32],
909            _headers: Option<http::HeaderMap>,
910            _body: Option<Vec<u8>>,
911        ) -> Result<reqwest::Response, BoxError> {
912            Err("not implemented".into())
913        }
914
915        async fn https_signed_rpc<T>(
916            &self,
917            _endpoint: &str,
918            _method: &str,
919            _args: impl Serialize + Send,
920        ) -> Result<T, BoxError>
921        where
922            T: DeserializeOwned,
923        {
924            Err("not implemented".into())
925        }
926    }
927
928    impl crate::CanisterCaller for TestContext {
929        async fn canister_query<In, Out>(
930            &self,
931            _canister: &Principal,
932            _method: &str,
933            _args: In,
934        ) -> Result<Out, BoxError>
935        where
936            In: ArgumentEncoder + Send,
937            Out: CandidType + for<'a> candid::Deserialize<'a>,
938        {
939            Err("not implemented".into())
940        }
941
942        async fn canister_update<In, Out>(
943            &self,
944            _canister: &Principal,
945            _method: &str,
946            _args: In,
947        ) -> Result<Out, BoxError>
948        where
949            In: ArgumentEncoder + Send,
950            Out: CandidType + for<'a> candid::Deserialize<'a>,
951        {
952            Err("not implemented".into())
953        }
954    }
955
956    impl BaseContext for TestContext {
957        async fn remote_tool_call(
958            &self,
959            _endpoint: &str,
960            _args: ToolInput<Json>,
961        ) -> Result<ToolOutput<Json>, BoxError> {
962            Err("not implemented".into())
963        }
964    }
965
    /// Fixture tool whose `call` returns its `id` as a string; downcast tests use
    /// the `id` to confirm they recovered the exact registered instance.
    struct ExampleTool {
        /// Distinguishing value echoed by `call` and checked after downcasting.
        id: usize,
    }
969
    /// Second fixture tool: the non-matching target type in downcast tests.
    struct OtherTool;
971
    /// Typed arguments accepted by `TaggedTool`.
    #[derive(serde::Deserialize)]
    struct EchoArgs {
        /// Text copied into the tool's JSON output under `"value"`.
        value: String,
        /// When `true`, the tool returns an error instead of output.
        fail: bool,
    }
977
    /// Fixture tool that declares `text` and `code` resource tags and takes `EchoArgs`.
    struct TaggedTool;
979
    /// Fixture tool named `bad-tool`; the hyphen violates the tool naming rules,
    /// so registering it in a `ToolSet` is expected to fail.
    struct InvalidTool;
981
982    fn resource(id: u64, tags: &[&str]) -> Resource {
983        Resource {
984            _id: id,
985            name: format!("resource-{id}"),
986            tags: tags.iter().map(|tag| tag.to_string()).collect(),
987            ..Default::default()
988        }
989    }
990
991    impl Tool<TestContext> for ExampleTool {
992        type Args = ();
993        type Output = String;
994
995        fn name(&self) -> String {
996            "example_tool".to_string()
997        }
998
999        fn description(&self) -> String {
1000            "Example tool used for downcast tests".to_string()
1001        }
1002
1003        fn definition(&self) -> FunctionDefinition {
1004            FunctionDefinition {
1005                name: self.name(),
1006                description: self.description(),
1007                parameters: json!({
1008                    "type": "object",
1009                    "properties": {},
1010                    "required": [],
1011                    "additionalProperties": false
1012                }),
1013                strict: Some(true),
1014            }
1015        }
1016
1017        async fn call(
1018            &self,
1019            _ctx: TestContext,
1020            _args: Self::Args,
1021            _resources: Vec<Resource>,
1022        ) -> Result<ToolOutput<Self::Output>, BoxError> {
1023            Ok(ToolOutput::new(self.id.to_string()))
1024        }
1025    }
1026
1027    impl Tool<TestContext> for OtherTool {
1028        type Args = ();
1029        type Output = String;
1030
1031        fn name(&self) -> String {
1032            "other_tool".to_string()
1033        }
1034
1035        fn description(&self) -> String {
1036            "Other tool used for downcast tests".to_string()
1037        }
1038
1039        fn definition(&self) -> FunctionDefinition {
1040            FunctionDefinition {
1041                name: self.name(),
1042                description: self.description(),
1043                parameters: json!({
1044                    "type": "object",
1045                    "properties": {},
1046                    "required": [],
1047                    "additionalProperties": false
1048                }),
1049                strict: Some(true),
1050            }
1051        }
1052
1053        async fn call(
1054            &self,
1055            _ctx: TestContext,
1056            _args: Self::Args,
1057            _resources: Vec<Resource>,
1058        ) -> Result<ToolOutput<Self::Output>, BoxError> {
1059            Ok(ToolOutput::new("other".to_string()))
1060        }
1061    }
1062
1063    impl Tool<TestContext> for TaggedTool {
1064        type Args = EchoArgs;
1065        type Output = Json;
1066
1067        fn name(&self) -> String {
1068            "tagged_tool".to_string()
1069        }
1070
1071        fn description(&self) -> String {
1072            "Tool that consumes text and code resources".to_string()
1073        }
1074
1075        fn definition(&self) -> FunctionDefinition {
1076            FunctionDefinition {
1077                name: self.name(),
1078                description: self.description(),
1079                parameters: json!({
1080                    "type": "object",
1081                    "properties": {
1082                        "value": {"type": "string"},
1083                        "fail": {"type": "boolean"}
1084                    },
1085                    "required": ["value", "fail"],
1086                    "additionalProperties": false
1087                }),
1088                strict: Some(true),
1089            }
1090        }
1091
1092        fn supported_resource_tags(&self) -> Vec<String> {
1093            vec!["text".to_string(), "code".to_string()]
1094        }
1095
1096        async fn call(
1097            &self,
1098            _ctx: TestContext,
1099            args: Self::Args,
1100            resources: Vec<Resource>,
1101        ) -> Result<ToolOutput<Self::Output>, BoxError> {
1102            if args.fail {
1103                return Err("forced failure".into());
1104            }
1105
1106            let mut output = ToolOutput::new(json!({
1107                "value": args.value,
1108                "resources": resources.len(),
1109            }));
1110            output.is_error = Some(false);
1111            Ok(output)
1112        }
1113    }
1114
1115    impl Tool<TestContext> for InvalidTool {
1116        type Args = ();
1117        type Output = String;
1118
1119        fn name(&self) -> String {
1120            "bad-tool".to_string()
1121        }
1122
1123        fn description(&self) -> String {
1124            "Invalid function name".to_string()
1125        }
1126
1127        fn definition(&self) -> FunctionDefinition {
1128            FunctionDefinition {
1129                name: self.name(),
1130                description: self.description(),
1131                parameters: json!({"type": "object"}),
1132                strict: Some(true),
1133            }
1134        }
1135
1136        async fn call(
1137            &self,
1138            _ctx: TestContext,
1139            _args: Self::Args,
1140            _resources: Vec<Resource>,
1141        ) -> Result<ToolOutput<Self::Output>, BoxError> {
1142            Ok(ToolOutput::new(String::new()))
1143        }
1144    }
1145
1146    #[test]
1147    fn dyn_tool_downcast_ref_returns_inner_tool() {
1148        let tool = Arc::new(ExampleTool { id: 7 });
1149        let mut tool_set = ToolSet::<TestContext>::new();
1150        tool_set.add(tool).unwrap();
1151
1152        let dyn_tool = tool_set.get("example_tool").unwrap();
1153        let concrete = dyn_tool.downcast_ref::<ExampleTool>().unwrap();
1154
1155        assert_eq!(concrete.id, 7);
1156        assert!(dyn_tool.downcast_ref::<OtherTool>().is_none());
1157    }
1158
1159    #[test]
1160    fn dyn_tool_downcast_returns_original_arc() {
1161        let tool = Arc::new(ExampleTool { id: 9 });
1162        let mut tool_set = ToolSet::<TestContext>::new();
1163        tool_set.add(tool.clone()).unwrap();
1164
1165        let dyn_tool = tool_set.get("example_tool").unwrap();
1166        let concrete = match dyn_tool.downcast::<ExampleTool>() {
1167            Ok(tool) => tool,
1168            Err(_) => panic!("expected downcast to ExampleTool to succeed"),
1169        };
1170
1171        assert_eq!(concrete.id, 9);
1172        assert!(Arc::ptr_eq(&concrete, &tool));
1173    }
1174
1175    #[test]
1176    fn dyn_tool_downcast_mismatch_returns_original_arc() {
1177        let tool = Arc::new(ExampleTool { id: 11 });
1178        let mut tool_set = ToolSet::<TestContext>::new();
1179        tool_set.add(tool).unwrap();
1180
1181        let dyn_tool = tool_set.get("example_tool").unwrap();
1182        let original = dyn_tool.clone();
1183        let err = match dyn_tool.downcast::<OtherTool>() {
1184            Ok(_) => panic!("expected downcast to OtherTool to fail"),
1185            Err(err) => err,
1186        };
1187
1188        assert!(Arc::ptr_eq(&err, &original));
1189        assert_eq!(err.name(), "example_tool");
1190    }
1191
1192    #[test]
1193    fn fixture_tools_cover_direct_methods() {
1194        futures::executor::block_on(async {
1195            let other = OtherTool;
1196            assert_eq!(other.name(), "other_tool");
1197            assert_eq!(other.description(), "Other tool used for downcast tests");
1198            let definition = other.definition();
1199            assert_eq!(definition.name, "other_tool");
1200            assert_eq!(definition.description, "Other tool used for downcast tests");
1201            assert_eq!(definition.parameters["type"], "object");
1202            let output = other
1203                .call(TestContext::default(), (), Vec::new())
1204                .await
1205                .unwrap();
1206            assert_eq!(output.output, "other");
1207
1208            let invalid = InvalidTool;
1209            assert_eq!(invalid.name(), "bad-tool");
1210            assert_eq!(invalid.description(), "Invalid function name");
1211            let definition = invalid.definition();
1212            assert_eq!(definition.name, "bad-tool");
1213            assert_eq!(definition.description, "Invalid function name");
1214            assert_eq!(definition.parameters["type"], "object");
1215            let output = invalid
1216                .call(TestContext::default(), (), Vec::new())
1217                .await
1218                .unwrap();
1219            assert!(output.output.is_empty());
1220        });
1221    }
1222
1223    #[test]
1224    fn tool_default_methods_call_raw_and_dyn_wrapper_forward_calls() {
1225        futures::executor::block_on(async {
1226            let tool = Arc::new(ExampleTool { id: 42 });
1227            let mut resources = vec![resource(1, &["text"])];
1228
1229            assert!(tool.supported_resource_tags().is_empty());
1230            assert!(tool.select_resources(&mut resources).is_empty());
1231            assert_eq!(resources.len(), 1);
1232            tool.init(TestContext::default()).await.unwrap();
1233
1234            let raw = tool
1235                .call_raw(TestContext::default(), Json::Null, Vec::new())
1236                .await
1237                .unwrap();
1238            assert_eq!(raw.output, json!("42"));
1239            assert_eq!(raw.usage.requests, 1);
1240
1241            let invalid = tool
1242                .call_raw(TestContext::default(), json!({"bad": true}), Vec::new())
1243                .await
1244                .unwrap_err();
1245            assert!(invalid.to_string().contains("invalid args"));
1246
1247            let mut tool_set = ToolSet::<TestContext>::new();
1248            tool_set.add(tool).unwrap();
1249            let dyn_tool = tool_set.get("EXAMPLE_TOOL").unwrap();
1250
1251            assert_eq!(dyn_tool.name(), "example_tool");
1252            assert_eq!(dyn_tool.definition().name, "example_tool");
1253            assert!(dyn_tool.supported_resource_tags().is_empty());
1254            dyn_tool.init(TestContext::default()).await.unwrap();
1255
1256            let output = dyn_tool
1257                .call(TestContext::default(), Json::Null, Vec::new())
1258                .await
1259                .unwrap();
1260            assert_eq!(output.output, json!("42"));
1261            assert_eq!(output.usage.requests, 1);
1262        });
1263    }
1264
    /// End-to-end check of `ToolSet`: case-insensitive lookup, definition and function
    /// listing, tag-based resource selection, dynamic calls (success and failure), and
    /// registration errors. Statement order matters: `resources` is mutated by selection,
    /// and the duplicate-name check relies on the earlier registration of `example_tool`.
    #[test]
    fn tool_set_registry_filters_resources_and_reports_errors() {
        futures::executor::block_on(async {
            let mut tool_set = ToolSet::<TestContext>::new();
            tool_set.add(Arc::new(ExampleTool { id: 1 })).unwrap();
            tool_set.add(Arc::new(TaggedTool)).unwrap();

            // `contains` accepts any casing; `contains_lowercase` is given an
            // already-lowercase name (presumably it skips normalization — confirm).
            assert!(tool_set.contains("EXAMPLE_TOOL"));
            assert!(tool_set.contains_lowercase("tagged_tool"));
            assert!(!tool_set.contains("missing_tool"));
            assert_eq!(
                tool_set.names(),
                vec!["example_tool".to_string(), "tagged_tool".to_string()]
            );

            let definition = tool_set.definition("TAGGED_TOOL").unwrap();
            assert_eq!(definition.name, "tagged_tool");
            assert!(tool_set.definition("missing_tool").is_none());

            // Unknown names in a selection are skipped; `None` selects every tool.
            let selected_names = vec!["TAGGED_TOOL".to_string(), "missing_tool".to_string()];
            let selected_definitions = tool_set.definitions(Some(&selected_names));
            assert_eq!(selected_definitions.len(), 1);
            assert_eq!(selected_definitions[0].name, "tagged_tool");
            assert_eq!(tool_set.definitions(None).len(), 2);

            // `functions` carries each tool's declared resource tags alongside its definition.
            let selected_functions = tool_set.functions(Some(&selected_names));
            assert_eq!(selected_functions.len(), 1);
            assert_eq!(
                selected_functions[0].supported_resource_tags,
                vec!["text".to_string(), "code".to_string()]
            );
            assert_eq!(tool_set.functions(None).len(), 2);

            // Resources whose tags match the tool's (`text` / `code`) are moved out of
            // `resources`; the rest stay behind in their original order.
            let mut resources = vec![
                resource(1, &["image"]),
                resource(2, &["text"]),
                resource(3, &["code", "text"]),
                resource(4, &["audio"]),
            ];
            let selected = tool_set.select_resources("TAGGED_TOOL", &mut resources);
            assert_eq!(
                selected
                    .iter()
                    .map(|resource| resource._id)
                    .collect::<Vec<_>>(),
                vec![2, 3]
            );
            assert_eq!(
                resources
                    .iter()
                    .map(|resource| resource._id)
                    .collect::<Vec<_>>(),
                vec![1, 4]
            );
            // An unknown tool selects nothing.
            assert!(
                tool_set
                    .select_resources("missing_tool", &mut resources)
                    .is_empty()
            );

            // Dynamic call with valid JSON args: the typed output is serialized and one request is counted.
            let dyn_tool = tool_set.get_lowercase("tagged_tool").unwrap();
            let output = dyn_tool
                .call(
                    TestContext::default(),
                    json!({"value": "ok", "fail": false}),
                    vec![resource(9, &["text"])],
                )
                .await
                .unwrap();
            assert_eq!(output.output["value"], "ok");
            assert_eq!(output.output["resources"], 1);
            assert_eq!(output.is_error, Some(false));
            assert_eq!(output.usage.requests, 1);
            assert!(tool_set.get("missing_tool").is_none());
            assert!(tool_set.get_lowercase("missing_tool").is_none());

            // The tool's "forced failure" reaches the caller as an error whose message
            // contains "call failed".
            let failed = dyn_tool
                .call(
                    TestContext::default(),
                    json!({"value": "bad", "fail": true}),
                    Vec::new(),
                )
                .await
                .unwrap_err();
            assert!(failed.to_string().contains("call failed"));

            // Registration rejects a second `example_tool`...
            let duplicate = tool_set.add(Arc::new(ExampleTool { id: 2 })).unwrap_err();
            assert!(duplicate.to_string().contains("already exists"));

            // ...and a name containing a character outside `[a-z0-9_]` (`bad-tool`).
            let invalid = tool_set.add(Arc::new(InvalidTool)).unwrap_err();
            assert!(invalid.to_string().contains("invalid character"));
        });
    }
1358
1359    #[test]
1360    fn test_tool_context_mock_features_cover_default_paths() {
1361        futures::executor::block_on(async {
1362            let ctx = TestContext::default();
1363            assert_eq!(*ctx.engine_id(), Principal::management_canister());
1364            assert_eq!(ctx.engine_name(), "test-engine");
1365            assert_eq!(*ctx.caller(), Principal::anonymous());
1366            assert!(ctx.meta().user.is_none());
1367            assert!(!ctx.cancellation_token().is_cancelled());
1368            assert_eq!(ctx.time_elapsed(), Duration::ZERO);
1369
1370            assert_eq!(ctx.a256gcm_key(Vec::new()).await.unwrap(), [0; 32]);
1371            assert_eq!(
1372                ctx.ed25519_sign_message(Vec::new(), b"message")
1373                    .await
1374                    .unwrap(),
1375                [0; 64]
1376            );
1377            ctx.ed25519_verify(Vec::new(), b"message", &[0; 64])
1378                .await
1379                .unwrap();
1380            assert_eq!(ctx.ed25519_public_key(Vec::new()).await.unwrap(), [0; 32]);
1381            assert_eq!(
1382                ctx.secp256k1_sign_message_bip340(Vec::new(), b"message")
1383                    .await
1384                    .unwrap(),
1385                [0; 64]
1386            );
1387            ctx.secp256k1_verify_bip340(Vec::new(), b"message", &[0; 64])
1388                .await
1389                .unwrap();
1390            assert_eq!(
1391                ctx.secp256k1_sign_message_ecdsa(Vec::new(), b"message")
1392                    .await
1393                    .unwrap(),
1394                [0; 64]
1395            );
1396            assert_eq!(
1397                ctx.secp256k1_sign_digest_ecdsa(Vec::new(), &[0; 32])
1398                    .await
1399                    .unwrap(),
1400                [0; 64]
1401            );
1402            ctx.secp256k1_verify_ecdsa(Vec::new(), &[0; 32], &[0; 64])
1403                .await
1404                .unwrap();
1405            assert_eq!(ctx.secp256k1_public_key(Vec::new()).await.unwrap(), [0; 33]);
1406
1407            assert!(ctx.store_get(&Path::from("missing")).await.is_err());
1408            assert!(
1409                ctx.store_list(None, &Path::default())
1410                    .await
1411                    .unwrap()
1412                    .is_empty()
1413            );
1414            assert!(
1415                ctx.store_put(&Path::from("file"), PutMode::Overwrite, bytes::Bytes::new())
1416                    .await
1417                    .is_err()
1418            );
1419            assert!(
1420                ctx.store_rename_if_not_exists(&Path::from("a"), &Path::from("b"))
1421                    .await
1422                    .is_err()
1423            );
1424            ctx.store_delete(&Path::from("file")).await.unwrap();
1425
1426            assert!(!ctx.cache_contains("key"));
1427            assert!(ctx.cache_get::<String>("key").await.is_err());
1428            assert!(
1429                ctx.cache_get_with("key", async { Ok(("value".to_string(), None)) })
1430                    .await
1431                    .is_err()
1432            );
1433            ctx.cache_set("key", ("value".to_string(), None)).await;
1434            assert!(
1435                !ctx.cache_set_if_not_exists("key", ("value".to_string(), None))
1436                    .await
1437            );
1438            assert!(!ctx.cache_delete("key").await);
1439            assert_eq!(ctx.cache_raw_iter().count(), 0);
1440
1441            assert!(
1442                ctx.https_call("https://example.test", http::Method::GET, None, None)
1443                    .await
1444                    .is_err()
1445            );
1446            assert!(
1447                ctx.https_signed_call(
1448                    "https://example.test",
1449                    http::Method::POST,
1450                    [0; 32],
1451                    None,
1452                    None,
1453                )
1454                .await
1455                .is_err()
1456            );
1457            let rpc: Result<String, BoxError> = ctx
1458                .https_signed_rpc("https://example.test", "method", &())
1459                .await;
1460            assert!(rpc.is_err());
1461
1462            let query: Result<String, BoxError> = ctx
1463                .canister_query(&Principal::anonymous(), "query", ())
1464                .await;
1465            assert!(query.is_err());
1466            let update: Result<String, BoxError> = ctx
1467                .canister_update(&Principal::anonymous(), "update", ())
1468                .await;
1469            assert!(update.is_err());
1470
1471            assert!(
1472                ctx.remote_tool_call(
1473                    "https://example.test",
1474                    ToolInput::new("tool".to_string(), Json::Null),
1475                )
1476                .await
1477                .is_err()
1478            );
1479        });
1480    }
1481
    /// Fixture tool with a configurable name and group id, used to test how
    /// `ToolSet::groups` aggregates tools.
    struct GroupedTool {
        /// Returned verbatim by `name()`.
        name: &'static str,
        /// Group id; also the prefix of the group's title, description, and instructions.
        group: &'static str,
    }
1486
1487    impl Tool<TestContext> for GroupedTool {
1488        type Args = ();
1489        type Output = String;
1490
1491        fn name(&self) -> String {
1492            self.name.to_string()
1493        }
1494
1495        fn description(&self) -> String {
1496            "Grouped tool fixture".to_string()
1497        }
1498
1499        fn definition(&self) -> FunctionDefinition {
1500            FunctionDefinition {
1501                name: self.name(),
1502                description: self.description(),
1503                parameters: json!({
1504                    "type": "object",
1505                    "properties": {},
1506                    "required": [],
1507                    "additionalProperties": false
1508                }),
1509                strict: Some(true),
1510            }
1511        }
1512
1513        fn group(&self) -> Option<ToolGroupInfo> {
1514            Some(ToolGroupInfo {
1515                id: self.group.to_string(),
1516                title: format!("{} title", self.group),
1517                description: format!("{} description", self.group),
1518                instructions: Some(format!("{} instructions", self.group)),
1519            })
1520        }
1521
1522        async fn call(
1523            &self,
1524            _ctx: TestContext,
1525            _args: Self::Args,
1526            _resources: Vec<Resource>,
1527        ) -> Result<ToolOutput<Self::Output>, BoxError> {
1528            Ok(ToolOutput::new(String::new()))
1529        }
1530    }
1531
1532    #[test]
1533    fn tool_set_groups_aggregate_members_by_id() {
1534        let mut tool_set = ToolSet::<TestContext>::new();
1535        tool_set
1536            .add(Arc::new(GroupedTool {
1537                name: "fs_write",
1538                group: "fs",
1539            }))
1540            .unwrap();
1541        tool_set
1542            .add(Arc::new(GroupedTool {
1543                name: "fs_read",
1544                group: "fs",
1545            }))
1546            .unwrap();
1547        tool_set
1548            .add(Arc::new(GroupedTool {
1549                name: "mem_get",
1550                group: "memory",
1551            }))
1552            .unwrap();
1553        // A tool with no group declaration is excluded from every group.
1554        tool_set.add(Arc::new(ExampleTool { id: 1 })).unwrap();
1555
1556        let groups = tool_set.groups();
1557        assert_eq!(groups.len(), 2);
1558
1559        let fs = groups.iter().find(|group| group.id == "fs").unwrap();
1560        // Members reflect the registered tools, sorted for determinism.
1561        assert_eq!(
1562            fs.members,
1563            vec!["fs_read".to_string(), "fs_write".to_string()]
1564        );
1565        assert_eq!(fs.title, "fs title");
1566        assert_eq!(fs.instructions.as_deref(), Some("fs instructions"));
1567
1568        let memory = groups.iter().find(|group| group.id == "memory").unwrap();
1569        assert_eq!(memory.members, vec!["mem_get".to_string()]);
1570    }
1571}