1use serde::{Deserialize, Serialize, de::DeserializeOwned};
14use std::{any::Any, collections::BTreeMap, future::Future, marker::PhantomData, sync::Arc};
15
16use crate::{
17 BoxError, BoxFut, BoxPinFut, Function, Json, Resource, ToolInput, ToolOutput,
18 context::BaseContext, model::FunctionDefinition, select_resources, validate_function_name,
19};
20
21pub trait Tool<C>: Send + Sync
26where
27 C: BaseContext + Send + Sync,
28{
29 type Args: DeserializeOwned + Send;
31
32 type Output: Serialize;
34
35 fn name(&self) -> String;
44
45 fn description(&self) -> String;
47
48 fn definition(&self) -> FunctionDefinition;
53
54 fn group(&self) -> Option<ToolGroupInfo> {
60 None
61 }
62
63 fn supported_resource_tags(&self) -> Vec<String> {
72 Vec::new()
73 }
74
75 fn select_resources(&self, resources: &mut Vec<Resource>) -> Vec<Resource> {
77 let supported_tags = self.supported_resource_tags();
78 select_resources(resources, &supported_tags)
79 }
80
81 fn init(&self, _ctx: C) -> impl Future<Output = Result<(), BoxError>> + Send {
85 futures::future::ready(Ok(()))
86 }
87
88 fn call(
98 &self,
99 ctx: C,
100 args: Self::Args,
101 resources: Vec<Resource>,
102 ) -> impl Future<Output = Result<ToolOutput<Self::Output>, BoxError>> + Send;
103
104 fn call_raw(
106 &self,
107 ctx: C,
108 args: Json,
109 resources: Vec<Resource>,
110 ) -> impl Future<Output = Result<ToolOutput<Json>, BoxError>> + Send {
111 async move {
112 let args: Self::Args = serde_json::from_value(args)
113 .map_err(|err| format!("tool {}, invalid args: {}", self.name(), err))?;
114 let mut result = self
115 .call(ctx, args, resources)
116 .await
117 .map_err(|err| format!("tool {}, call failed: {}", self.name(), err))?;
118 let output = serde_json::to_value(&result.output)?;
119 if result.usage.requests == 0 {
120 result.usage.requests = 1;
121 }
122
123 Ok(ToolOutput {
124 output,
125 is_error: result.is_error,
126 artifacts: result.artifacts,
127 usage: result.usage,
128 tools_usage: result.tools_usage,
129 })
130 }
131 }
132}
133
/// Object-safe counterpart of [`Tool`], used to store heterogeneous tools
/// behind `Arc<dyn DynTool<C>>`.
///
/// JSON arguments/outputs and boxed futures replace [`Tool`]'s associated
/// types and `impl Future` returns, which cannot be used through `dyn`.
pub trait DynTool<C>: Send + Sync
where
    C: BaseContext + Send + Sync,
{
    /// Returns the underlying concrete tool as `&dyn Any`, enabling
    /// `downcast_ref` to the user's tool type.
    fn as_any(&self) -> &(dyn Any + Send + Sync);

    /// Converts into an `Arc<dyn Any>` of the underlying concrete tool,
    /// enabling `Arc::downcast`.
    fn into_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;

    /// Returns the tool's name.
    fn name(&self) -> String;

    /// Returns the tool's function definition.
    fn definition(&self) -> FunctionDefinition;

    /// Returns the tool's group, if any. Defaults to `None`.
    fn group(&self) -> Option<ToolGroupInfo> {
        None
    }

    /// Returns the resource tags the tool can consume.
    fn supported_resource_tags(&self) -> Vec<String>;

    /// Runs the tool's initialization hook.
    fn init(&self, ctx: C) -> BoxPinFut<Result<(), BoxError>>;

    /// Invokes the tool with JSON arguments and the selected resources.
    fn call(
        &self,
        ctx: C,
        args: Json,
        resources: Vec<Resource>,
    ) -> BoxPinFut<Result<ToolOutput<Json>, BoxError>>;
}
173
/// Describes the group a tool belongs to, as returned by `Tool::group`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolGroupInfo {
    /// Group identifier; `ToolSet::groups` merges tools that share this id.
    pub id: String,
    /// Title of the group.
    pub title: String,
    /// Description of the group.
    pub description: String,
    /// Optional extra instructions for the group. Omitted when serializing
    /// if `None`, and defaulted to `None` when missing on deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
}
196
/// A tool group together with the names of its member tools.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolGroup {
    /// Group identifier.
    pub id: String,
    /// Title of the group.
    pub title: String,
    /// Description of the group.
    pub description: String,
    /// Optional extra instructions. Omitted when serializing if `None`, and
    /// defaulted to `None` when missing on deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Names of the tools in this group (`ToolSet::groups` fills these with
    /// the registry's lowercased tool names, sorted).
    pub members: Vec<String>,
}
225
226impl ToolGroup {
227 pub fn from_info(info: ToolGroupInfo, members: Vec<String>) -> Self {
229 Self {
230 id: info.id,
231 title: info.title,
232 description: info.description,
233 instructions: info.instructions,
234 members,
235 }
236 }
237}
238
/// A named collection of tools, dispatched by name through JSON
/// [`ToolInput`]s rather than registered one by one as typed [`Tool`]s.
pub trait ToolProvider<C>: Send + Sync
where
    C: BaseContext + Send + Sync,
{
    /// Returns the provider's name; `ToolProviderSet::add` registers the
    /// provider under the ASCII-lowercased form of it.
    fn name(&self) -> String;

    /// Returns definitions for the provider's tools. `names` presumably
    /// restricts the result to the listed tools, with `None` meaning all of
    /// them — NOTE(review): confirm against implementations.
    fn definitions(&self, names: Option<&[String]>) -> Vec<FunctionDefinition>;

    /// Returns the tool groups exposed by this provider. Defaults to none.
    fn groups(&self) -> Vec<ToolGroup> {
        Vec::new()
    }

    /// Returns `true` if the provider exposes a tool named `lowercase_name`,
    /// which callers are expected to have ASCII-lowercased already.
    ///
    /// The default asks `definitions` for just that name and compares the
    /// returned names ASCII case-insensitively; override it when a cheaper
    /// lookup is available.
    fn contains_lowercase(&self, lowercase_name: &str) -> bool {
        self.definitions(Some(&[lowercase_name.to_string()]))
            .iter()
            .any(|definition| definition.name.eq_ignore_ascii_case(lowercase_name))
    }

    /// Returns the resource tags supported by the named tool. Defaults to none.
    fn supported_resource_tags(&self, _name: &str) -> Vec<String> {
        Vec::new()
    }

    /// Selects from `resources` the entries matching the tags supported by
    /// tool `name`, via the crate-level `select_resources` helper.
    fn select_resources(&self, name: &str, resources: &mut Vec<Resource>) -> Vec<Resource> {
        let supported_tags = self.supported_resource_tags(name);
        select_resources(resources, &supported_tags)
    }

    /// Initialization hook. The default does nothing and succeeds.
    fn init(&self, _ctx: C) -> BoxFut<'_, Result<(), BoxError>> {
        Box::pin(async { Ok(()) })
    }

    /// Refresh hook. The default does nothing and succeeds.
    fn refresh(&self) -> BoxFut<'_, Result<(), BoxError>> {
        Box::pin(async { Ok(()) })
    }

    /// Invokes the tool named by `input.name` with JSON arguments.
    /// `ToolProviderSet::call` lowercases `input.name` before calling this.
    fn call(
        &self,
        ctx: C,
        input: ToolInput<Json>,
    ) -> BoxFut<'_, Result<ToolOutput<Json>, BoxError>>;
}
302
303impl<C> dyn DynTool<C>
304where
305 C: BaseContext + Send + Sync + 'static,
306{
307 pub fn downcast_ref<T>(&self) -> Option<&T>
309 where
310 T: Tool<C> + 'static,
311 {
312 self.as_any().downcast_ref::<T>()
313 }
314
315 pub fn downcast<T>(self: Arc<Self>) -> Result<Arc<T>, Arc<Self>>
317 where
318 T: Tool<C> + 'static,
319 {
320 match self.clone().into_any().downcast::<T>() {
321 Ok(tool) => Ok(tool),
322 Err(_) => Err(self),
323 }
324 }
325}
326
/// Adapter that implements [`DynTool`] for a typed [`Tool`] held in an `Arc`.
///
/// The `PhantomData<C>` ties the wrapper to the context type `C`, which would
/// otherwise appear only in the trait bound.
struct ToolWrapper<T, C>(Arc<T>, PhantomData<C>)
where
    T: Tool<C> + 'static,
    C: BaseContext + Send + Sync + 'static;
332
333impl<T, C> DynTool<C> for ToolWrapper<T, C>
334where
335 T: Tool<C> + 'static,
336 C: BaseContext + Send + Sync + 'static,
337{
338 fn as_any(&self) -> &(dyn Any + Send + Sync) {
339 self.0.as_ref()
340 }
341
342 fn into_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
343 self.0.clone()
344 }
345
346 fn name(&self) -> String {
347 self.0.name()
348 }
349
350 fn definition(&self) -> FunctionDefinition {
351 self.0.definition()
352 }
353
354 fn group(&self) -> Option<ToolGroupInfo> {
355 self.0.group()
356 }
357
358 fn supported_resource_tags(&self) -> Vec<String> {
359 self.0.supported_resource_tags()
360 }
361
362 fn init(&self, ctx: C) -> BoxPinFut<Result<(), BoxError>> {
363 let tool = self.0.clone();
364 Box::pin(async move { tool.init(ctx).await })
365 }
366
367 fn call(
368 &self,
369 ctx: C,
370 args: Json,
371 resources: Vec<Resource>,
372 ) -> BoxPinFut<Result<ToolOutput<Json>, BoxError>> {
373 let tool = self.0.clone();
374 Box::pin(async move { tool.call_raw(ctx, args, resources).await })
375 }
376}
377
/// Registry of typed tools, keyed by ASCII-lowercased tool name.
#[derive(Default)]
pub struct ToolSet<C: BaseContext> {
    /// Registered tools as trait objects; `ToolSet::add` inserts them under
    /// their lowercased names.
    pub set: BTreeMap<String, Arc<dyn DynTool<C>>>,
}
387
/// Registry of tool providers, keyed by ASCII-lowercased provider name.
#[derive(Default)]
pub struct ToolProviderSet<C: BaseContext> {
    /// Registered providers; `ToolProviderSet::add` inserts them under their
    /// lowercased names.
    pub set: BTreeMap<String, Arc<dyn ToolProvider<C>>>,
}
394
395impl<C> ToolProviderSet<C>
396where
397 C: BaseContext + Clone + Send + Sync + 'static,
398{
399 pub fn new() -> Self {
401 Self {
402 set: BTreeMap::new(),
403 }
404 }
405
406 pub fn contains_provider(&self, name: &str) -> bool {
408 self.set.contains_key(&name.to_ascii_lowercase())
409 }
410
411 pub fn add<T>(&mut self, provider: Arc<T>) -> Result<(), BoxError>
413 where
414 T: ToolProvider<C> + Send + Sync + 'static,
415 {
416 let name = provider.name().to_ascii_lowercase();
417 validate_function_name(&name)?;
418 if self.set.contains_key(&name) {
419 return Err(format!("tool provider {} already exists", name).into());
420 }
421
422 self.set.insert(name, provider);
423 Ok(())
424 }
425
426 pub fn contains_lowercase(&self, lowercase_name: &str) -> bool {
428 self.set
429 .values()
430 .any(|provider| provider.contains_lowercase(lowercase_name))
431 }
432
433 pub fn definitions(&self, names: Option<&[String]>) -> Vec<FunctionDefinition> {
435 match names {
436 Some([]) => Vec::new(),
437 _ => {
438 let mut definitions = BTreeMap::new();
439 for provider in self.set.values() {
440 for definition in provider.definitions(names) {
441 definitions
442 .entry(definition.name.to_ascii_lowercase())
443 .or_insert(definition);
444 }
445 }
446 definitions.into_values().collect()
447 }
448 }
449 }
450
451 pub fn groups(&self) -> Vec<ToolGroup> {
453 self.set
454 .values()
455 .flat_map(|provider| provider.groups())
456 .collect()
457 }
458
459 pub fn functions(&self, names: Option<&[String]>) -> Vec<Function> {
461 self.definitions(names)
462 .into_iter()
463 .map(|definition| {
464 let supported_resource_tags = self
465 .set
466 .values()
467 .find(|provider| provider.contains_lowercase(&definition.name))
468 .map(|provider| provider.supported_resource_tags(&definition.name))
469 .unwrap_or_default();
470 Function {
471 definition,
472 supported_resource_tags,
473 }
474 })
475 .collect()
476 }
477
478 pub fn select_resources(&self, name: &str, resources: &mut Vec<Resource>) -> Vec<Resource> {
480 let lowercase_name = name.to_ascii_lowercase();
481 self.set
482 .values()
483 .find(|provider| provider.contains_lowercase(&lowercase_name))
484 .map(|provider| provider.select_resources(&lowercase_name, resources))
485 .unwrap_or_default()
486 }
487
488 pub async fn init_all(&self, ctx: C) -> Result<(), BoxError> {
490 for provider in self.set.values() {
491 provider.init(ctx.clone()).await?;
492 }
493 Ok(())
494 }
495
496 pub async fn refresh_all(&self) -> Result<(), BoxError> {
498 for provider in self.set.values() {
499 provider.refresh().await?;
500 }
501 Ok(())
502 }
503
504 pub async fn call(
506 &self,
507 ctx: C,
508 mut input: ToolInput<Json>,
509 ) -> Result<ToolOutput<Json>, BoxError> {
510 input.name.make_ascii_lowercase();
511 let provider = self
512 .set
513 .values()
514 .find(|provider| provider.contains_lowercase(&input.name))
515 .ok_or_else(|| format!("tool {} not found", input.name))?;
516 provider.call(ctx, input).await
517 }
518}
519
520impl<C> ToolSet<C>
521where
522 C: BaseContext + Send + Sync + 'static,
523{
524 pub fn new() -> Self {
526 Self {
527 set: BTreeMap::new(),
528 }
529 }
530
531 pub fn contains(&self, name: &str) -> bool {
533 self.set.contains_key(&name.to_ascii_lowercase())
534 }
535
536 pub fn contains_lowercase(&self, lowercase_name: &str) -> bool {
538 self.set.contains_key(lowercase_name)
539 }
540
541 pub fn names(&self) -> Vec<String> {
543 self.set.keys().cloned().collect()
544 }
545
546 pub fn groups(&self) -> Vec<ToolGroup> {
553 let mut grouped: BTreeMap<String, (ToolGroupInfo, Vec<String>)> = BTreeMap::new();
554 for (name, tool) in &self.set {
555 if let Some(info) = tool.group() {
556 grouped
557 .entry(info.id.clone())
558 .or_insert_with(|| (info, Vec::new()))
559 .1
560 .push(name.clone());
561 }
562 }
563
564 grouped
565 .into_values()
566 .map(|(info, mut members)| {
567 members.sort();
568 ToolGroup::from_info(info, members)
569 })
570 .collect()
571 }
572
573 pub fn definition(&self, name: &str) -> Option<FunctionDefinition> {
575 self.set
576 .get(&name.to_ascii_lowercase())
577 .map(|tool| tool.definition())
578 }
579
580 pub fn definitions(&self, names: Option<&[String]>) -> Vec<FunctionDefinition> {
588 match names {
589 None => self.set.values().map(|tool| tool.definition()).collect(),
590 Some(names) => names
591 .iter()
592 .filter_map(|name| {
593 self.set
594 .get(&name.to_ascii_lowercase())
595 .map(|tool| tool.definition())
596 })
597 .collect(),
598 }
599 }
600
601 pub fn functions(&self, names: Option<&[String]>) -> Vec<Function> {
609 match names {
610 None => self
611 .set
612 .values()
613 .map(|tool| Function {
614 definition: tool.definition(),
615 supported_resource_tags: tool.supported_resource_tags(),
616 })
617 .collect(),
618 Some(names) => names
619 .iter()
620 .filter_map(|name| {
621 self.set
622 .get(&name.to_ascii_lowercase())
623 .map(|tool| Function {
624 definition: tool.definition(),
625 supported_resource_tags: tool.supported_resource_tags(),
626 })
627 })
628 .collect(),
629 }
630 }
631
632 pub fn select_resources(&self, name: &str, resources: &mut Vec<Resource>) -> Vec<Resource> {
634 self.set
635 .get(&name.to_ascii_lowercase())
636 .map(|tool| {
637 let supported_tags = tool.supported_resource_tags();
638 select_resources(resources, &supported_tags)
639 })
640 .unwrap_or_default()
641 }
642
643 pub fn add<T>(&mut self, tool: Arc<T>) -> Result<(), BoxError>
648 where
649 T: Tool<C> + Send + Sync + 'static,
650 {
651 let name = tool.name().to_ascii_lowercase();
652 validate_function_name(&name)?;
653 if self.set.contains_key(&name) {
654 return Err(format!("tool {} already exists", name).into());
655 }
656
657 let tool_dyn = ToolWrapper(tool, PhantomData);
658 self.set.insert(name, Arc::new(tool_dyn));
659 Ok(())
660 }
661
662 pub fn get(&self, name: &str) -> Option<Arc<dyn DynTool<C>>> {
664 self.set.get(&name.to_ascii_lowercase()).cloned()
665 }
666
667 pub fn get_lowercase(&self, lowercase_name: &str) -> Option<Arc<dyn DynTool<C>>> {
669 self.set.get(lowercase_name).cloned()
670 }
671}
672
673#[cfg(test)]
674mod tests {
675 use super::*;
676 use candid::{CandidType, Principal, utils::ArgumentEncoder};
677 use serde_json::json;
678 use std::{sync::Arc, time::Duration};
679
680 use crate::{
681 BaseContext, CacheExpiry, CacheFeatures, CancellationToken, CanisterCaller, HttpFeatures,
682 KeysFeatures, ObjectMeta, Path, PutMode, PutResult, RequestMeta, StateFeatures,
683 StoreFeatures, ToolInput,
684 };
685
    /// Stub context used to drive tools in these tests.
    #[derive(Clone)]
    struct TestContext {
        engine_id: Principal,
        caller: Principal,
        meta: RequestMeta,
        cancellation_token: CancellationToken,
    }

    impl Default for TestContext {
        /// Management canister as engine id, anonymous caller, default request
        /// metadata and a fresh cancellation token.
        fn default() -> Self {
            Self {
                engine_id: Principal::management_canister(),
                caller: Principal::anonymous(),
                meta: RequestMeta::default(),
                cancellation_token: CancellationToken::new(),
            }
        }
    }
704
    // State accessors return the values stored in the context; the engine name
    // is fixed and elapsed time is always zero.
    impl StateFeatures for TestContext {
        fn engine_id(&self) -> &Principal {
            &self.engine_id
        }

        fn engine_name(&self) -> &str {
            "test-engine"
        }

        fn caller(&self) -> &Principal {
            &self.caller
        }

        fn meta(&self) -> &RequestMeta {
            &self.meta
        }

        fn cancellation_token(&self) -> CancellationToken {
            self.cancellation_token.clone()
        }

        fn time_elapsed(&self) -> Duration {
            Duration::ZERO
        }
    }

    // Key operations return all-zero keys/signatures and every verification
    // succeeds; no cryptography is performed.
    impl KeysFeatures for TestContext {
        async fn a256gcm_key(&self, _derivation_path: Vec<Vec<u8>>) -> Result<[u8; 32], BoxError> {
            Ok([0; 32])
        }

        async fn ed25519_sign_message(
            &self,
            _derivation_path: Vec<Vec<u8>>,
            _message: &[u8],
        ) -> Result<[u8; 64], BoxError> {
            Ok([0; 64])
        }

        async fn ed25519_verify(
            &self,
            _derivation_path: Vec<Vec<u8>>,
            _message: &[u8],
            _signature: &[u8],
        ) -> Result<(), BoxError> {
            Ok(())
        }

        async fn ed25519_public_key(
            &self,
            _derivation_path: Vec<Vec<u8>>,
        ) -> Result<[u8; 32], BoxError> {
            Ok([0; 32])
        }

        async fn secp256k1_sign_message_bip340(
            &self,
            _derivation_path: Vec<Vec<u8>>,
            _message: &[u8],
        ) -> Result<[u8; 64], BoxError> {
            Ok([0; 64])
        }

        async fn secp256k1_verify_bip340(
            &self,
            _derivation_path: Vec<Vec<u8>>,
            _message: &[u8],
            _signature: &[u8],
        ) -> Result<(), BoxError> {
            Ok(())
        }

        async fn secp256k1_sign_message_ecdsa(
            &self,
            _derivation_path: Vec<Vec<u8>>,
            _message: &[u8],
        ) -> Result<[u8; 64], BoxError> {
            Ok([0; 64])
        }

        async fn secp256k1_sign_digest_ecdsa(
            &self,
            _derivation_path: Vec<Vec<u8>>,
            _message_hash: &[u8],
        ) -> Result<[u8; 64], BoxError> {
            Ok([0; 64])
        }

        async fn secp256k1_verify_ecdsa(
            &self,
            _derivation_path: Vec<Vec<u8>>,
            _message_hash: &[u8],
            _signature: &[u8],
        ) -> Result<(), BoxError> {
            Ok(())
        }

        async fn secp256k1_public_key(
            &self,
            _derivation_path: Vec<Vec<u8>>,
        ) -> Result<[u8; 33], BoxError> {
            Ok([0; 33])
        }
    }
809
    // Store stub: reads, writes and renames fail with "not implemented";
    // listing yields nothing and deletes succeed.
    impl StoreFeatures for TestContext {
        async fn store_get(&self, _path: &Path) -> Result<(bytes::Bytes, ObjectMeta), BoxError> {
            Err("not implemented".into())
        }

        async fn store_list(
            &self,
            _prefix: Option<&Path>,
            _offset: &Path,
        ) -> Result<Vec<ObjectMeta>, BoxError> {
            Ok(Vec::new())
        }

        async fn store_put(
            &self,
            _path: &Path,
            _mode: PutMode,
            _value: bytes::Bytes,
        ) -> Result<PutResult, BoxError> {
            Err("not implemented".into())
        }

        async fn store_rename_if_not_exists(
            &self,
            _from: &Path,
            _to: &Path,
        ) -> Result<(), BoxError> {
            Err("not implemented".into())
        }

        async fn store_delete(&self, _path: &Path) -> Result<(), BoxError> {
            Ok(())
        }
    }

    // Cache stub: always empty. Gets fail, sets are discarded, and
    // conditional set/delete report `false`.
    impl CacheFeatures for TestContext {
        fn cache_contains(&self, _key: &str) -> bool {
            false
        }

        async fn cache_get<T>(&self, _key: &str) -> Result<T, BoxError>
        where
            T: DeserializeOwned,
        {
            Err("not implemented".into())
        }

        // The `_init` future is dropped without being polled.
        async fn cache_get_with<T, F>(&self, _key: &str, _init: F) -> Result<T, BoxError>
        where
            T: Sized + DeserializeOwned + Serialize + Send,
            F: Future<Output = Result<(T, Option<CacheExpiry>), BoxError>> + Send + 'static,
        {
            Err("not implemented".into())
        }

        async fn cache_set<T>(&self, _key: &str, _val: (T, Option<CacheExpiry>))
        where
            T: Sized + Serialize + Send,
        {
        }

        async fn cache_set_if_not_exists<T>(
            &self,
            _key: &str,
            _val: (T, Option<CacheExpiry>),
        ) -> bool
        where
            T: Sized + Serialize + Send,
        {
            false
        }

        async fn cache_delete(&self, _key: &str) -> bool {
            false
        }

        fn cache_raw_iter(
            &self,
        ) -> impl Iterator<Item = (Arc<String>, Arc<(bytes::Bytes, Option<CacheExpiry>)>)> {
            std::iter::empty()
        }
    }
892
    // Every outbound HTTP operation fails with "not implemented".
    impl HttpFeatures for TestContext {
        async fn https_call(
            &self,
            _url: &str,
            _method: http::Method,
            _headers: Option<http::HeaderMap>,
            _body: Option<Vec<u8>>,
        ) -> Result<reqwest::Response, BoxError> {
            Err("not implemented".into())
        }

        async fn https_signed_call(
            &self,
            _url: &str,
            _method: http::Method,
            _message_digest: [u8; 32],
            _headers: Option<http::HeaderMap>,
            _body: Option<Vec<u8>>,
        ) -> Result<reqwest::Response, BoxError> {
            Err("not implemented".into())
        }

        async fn https_signed_rpc<T>(
            &self,
            _endpoint: &str,
            _method: &str,
            _args: impl Serialize + Send,
        ) -> Result<T, BoxError>
        where
            T: DeserializeOwned,
        {
            Err("not implemented".into())
        }
    }

    // Canister queries and updates always fail with "not implemented".
    impl crate::CanisterCaller for TestContext {
        async fn canister_query<In, Out>(
            &self,
            _canister: &Principal,
            _method: &str,
            _args: In,
        ) -> Result<Out, BoxError>
        where
            In: ArgumentEncoder + Send,
            Out: CandidType + for<'a> candid::Deserialize<'a>,
        {
            Err("not implemented".into())
        }

        async fn canister_update<In, Out>(
            &self,
            _canister: &Principal,
            _method: &str,
            _args: In,
        ) -> Result<Out, BoxError>
        where
            In: ArgumentEncoder + Send,
            Out: CandidType + for<'a> candid::Deserialize<'a>,
        {
            Err("not implemented".into())
        }
    }

    // Remote tool calls always fail with "not implemented".
    impl BaseContext for TestContext {
        async fn remote_tool_call(
            &self,
            _endpoint: &str,
            _args: ToolInput<Json>,
        ) -> Result<ToolOutput<Json>, BoxError> {
            Err("not implemented".into())
        }
    }
965
    /// Tool whose output is its `id`, used to check downcasts reach the
    /// original instance.
    struct ExampleTool {
        id: usize,
    }

    /// Second tool type, used as the wrong target in downcast tests.
    struct OtherTool;

    /// Arguments for `TaggedTool`; `fail` forces the call to return an error.
    #[derive(serde::Deserialize)]
    struct EchoArgs {
        value: String,
        fail: bool,
    }

    /// Tool that accepts `text` and `code` resources.
    struct TaggedTool;

    /// Tool whose name contains `-`; the test expects `ToolSet::add` to reject it.
    struct InvalidTool;

    /// Builds a resource with the given id, a derived name and the given tags.
    fn resource(id: u64, tags: &[&str]) -> Resource {
        Resource {
            _id: id,
            name: format!("resource-{id}"),
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            ..Default::default()
        }
    }
990
    // No-argument tool that returns its `id` as a string.
    impl Tool<TestContext> for ExampleTool {
        type Args = ();
        type Output = String;

        fn name(&self) -> String {
            "example_tool".to_string()
        }

        fn description(&self) -> String {
            "Example tool used for downcast tests".to_string()
        }

        fn definition(&self) -> FunctionDefinition {
            FunctionDefinition {
                name: self.name(),
                description: self.description(),
                parameters: json!({
                    "type": "object",
                    "properties": {},
                    "required": [],
                    "additionalProperties": false
                }),
                strict: Some(true),
            }
        }

        async fn call(
            &self,
            _ctx: TestContext,
            _args: Self::Args,
            _resources: Vec<Resource>,
        ) -> Result<ToolOutput<Self::Output>, BoxError> {
            Ok(ToolOutput::new(self.id.to_string()))
        }
    }
1026
    // No-argument tool that always returns "other".
    impl Tool<TestContext> for OtherTool {
        type Args = ();
        type Output = String;

        fn name(&self) -> String {
            "other_tool".to_string()
        }

        fn description(&self) -> String {
            "Other tool used for downcast tests".to_string()
        }

        fn definition(&self) -> FunctionDefinition {
            FunctionDefinition {
                name: self.name(),
                description: self.description(),
                parameters: json!({
                    "type": "object",
                    "properties": {},
                    "required": [],
                    "additionalProperties": false
                }),
                strict: Some(true),
            }
        }

        async fn call(
            &self,
            _ctx: TestContext,
            _args: Self::Args,
            _resources: Vec<Resource>,
        ) -> Result<ToolOutput<Self::Output>, BoxError> {
            Ok(ToolOutput::new("other".to_string()))
        }
    }
1062
    // Echoes `value` and the number of resources it received. It fails when
    // `fail` is set and explicitly marks successful outputs as non-errors.
    impl Tool<TestContext> for TaggedTool {
        type Args = EchoArgs;
        type Output = Json;

        fn name(&self) -> String {
            "tagged_tool".to_string()
        }

        fn description(&self) -> String {
            "Tool that consumes text and code resources".to_string()
        }

        fn definition(&self) -> FunctionDefinition {
            FunctionDefinition {
                name: self.name(),
                description: self.description(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "value": {"type": "string"},
                        "fail": {"type": "boolean"}
                    },
                    "required": ["value", "fail"],
                    "additionalProperties": false
                }),
                strict: Some(true),
            }
        }

        fn supported_resource_tags(&self) -> Vec<String> {
            vec!["text".to_string(), "code".to_string()]
        }

        async fn call(
            &self,
            _ctx: TestContext,
            args: Self::Args,
            resources: Vec<Resource>,
        ) -> Result<ToolOutput<Self::Output>, BoxError> {
            if args.fail {
                return Err("forced failure".into());
            }

            let mut output = ToolOutput::new(json!({
                "value": args.value,
                "resources": resources.len(),
            }));
            output.is_error = Some(false);
            Ok(output)
        }
    }
1114
    // Tool named "bad-tool"; the hyphen is expected to fail
    // `validate_function_name` when the tool is added to a `ToolSet`.
    impl Tool<TestContext> for InvalidTool {
        type Args = ();
        type Output = String;

        fn name(&self) -> String {
            "bad-tool".to_string()
        }

        fn description(&self) -> String {
            "Invalid function name".to_string()
        }

        fn definition(&self) -> FunctionDefinition {
            FunctionDefinition {
                name: self.name(),
                description: self.description(),
                parameters: json!({"type": "object"}),
                strict: Some(true),
            }
        }

        async fn call(
            &self,
            _ctx: TestContext,
            _args: Self::Args,
            _resources: Vec<Resource>,
        ) -> Result<ToolOutput<Self::Output>, BoxError> {
            Ok(ToolOutput::new(String::new()))
        }
    }
1145
    /// `downcast_ref` reaches the wrapped `ExampleTool` and rejects other types.
    #[test]
    fn dyn_tool_downcast_ref_returns_inner_tool() {
        let tool = Arc::new(ExampleTool { id: 7 });
        let mut tool_set = ToolSet::<TestContext>::new();
        tool_set.add(tool).unwrap();

        let dyn_tool = tool_set.get("example_tool").unwrap();
        let concrete = dyn_tool.downcast_ref::<ExampleTool>().unwrap();

        assert_eq!(concrete.id, 7);
        assert!(dyn_tool.downcast_ref::<OtherTool>().is_none());
    }

    /// A successful `downcast` yields the same `Arc` that was registered.
    #[test]
    fn dyn_tool_downcast_returns_original_arc() {
        let tool = Arc::new(ExampleTool { id: 9 });
        let mut tool_set = ToolSet::<TestContext>::new();
        tool_set.add(tool.clone()).unwrap();

        let dyn_tool = tool_set.get("example_tool").unwrap();
        let concrete = match dyn_tool.downcast::<ExampleTool>() {
            Ok(tool) => tool,
            Err(_) => panic!("expected downcast to ExampleTool to succeed"),
        };

        assert_eq!(concrete.id, 9);
        assert!(Arc::ptr_eq(&concrete, &tool));
    }

    /// A failed `downcast` hands back the original trait-object `Arc`.
    #[test]
    fn dyn_tool_downcast_mismatch_returns_original_arc() {
        let tool = Arc::new(ExampleTool { id: 11 });
        let mut tool_set = ToolSet::<TestContext>::new();
        tool_set.add(tool).unwrap();

        let dyn_tool = tool_set.get("example_tool").unwrap();
        let original = dyn_tool.clone();
        let err = match dyn_tool.downcast::<OtherTool>() {
            Ok(_) => panic!("expected downcast to OtherTool to fail"),
            Err(err) => err,
        };

        assert!(Arc::ptr_eq(&err, &original));
        assert_eq!(err.name(), "example_tool");
    }
1191
    /// Calls the `OtherTool` and `InvalidTool` fixtures directly, bypassing
    /// any registry.
    #[test]
    fn fixture_tools_cover_direct_methods() {
        futures::executor::block_on(async {
            let other = OtherTool;
            assert_eq!(other.name(), "other_tool");
            assert_eq!(other.description(), "Other tool used for downcast tests");
            let definition = other.definition();
            assert_eq!(definition.name, "other_tool");
            assert_eq!(definition.description, "Other tool used for downcast tests");
            assert_eq!(definition.parameters["type"], "object");
            let output = other
                .call(TestContext::default(), (), Vec::new())
                .await
                .unwrap();
            assert_eq!(output.output, "other");

            let invalid = InvalidTool;
            assert_eq!(invalid.name(), "bad-tool");
            assert_eq!(invalid.description(), "Invalid function name");
            let definition = invalid.definition();
            assert_eq!(definition.name, "bad-tool");
            assert_eq!(definition.description, "Invalid function name");
            assert_eq!(definition.parameters["type"], "object");
            let output = invalid
                .call(TestContext::default(), (), Vec::new())
                .await
                .unwrap();
            assert!(output.output.is_empty());
        });
    }
1222
    /// Exercises the `Tool` default methods and `call_raw` (including the
    /// usage-request bump and the invalid-args error). It then checks that
    /// the `DynTool` wrapper forwards to them via a case-insensitive lookup.
    #[test]
    fn tool_default_methods_call_raw_and_dyn_wrapper_forward_calls() {
        futures::executor::block_on(async {
            let tool = Arc::new(ExampleTool { id: 42 });
            let mut resources = vec![resource(1, &["text"])];

            // No supported tags: nothing is selected and `resources` is untouched.
            assert!(tool.supported_resource_tags().is_empty());
            assert!(tool.select_resources(&mut resources).is_empty());
            assert_eq!(resources.len(), 1);
            tool.init(TestContext::default()).await.unwrap();

            // `()` args deserialize from JSON null.
            let raw = tool
                .call_raw(TestContext::default(), Json::Null, Vec::new())
                .await
                .unwrap();
            assert_eq!(raw.output, json!("42"));
            assert_eq!(raw.usage.requests, 1);

            let invalid = tool
                .call_raw(TestContext::default(), json!({"bad": true}), Vec::new())
                .await
                .unwrap_err();
            assert!(invalid.to_string().contains("invalid args"));

            let mut tool_set = ToolSet::<TestContext>::new();
            tool_set.add(tool).unwrap();
            let dyn_tool = tool_set.get("EXAMPLE_TOOL").unwrap();

            assert_eq!(dyn_tool.name(), "example_tool");
            assert_eq!(dyn_tool.definition().name, "example_tool");
            assert!(dyn_tool.supported_resource_tags().is_empty());
            dyn_tool.init(TestContext::default()).await.unwrap();

            let output = dyn_tool
                .call(TestContext::default(), Json::Null, Vec::new())
                .await
                .unwrap();
            assert_eq!(output.output, json!("42"));
            assert_eq!(output.usage.requests, 1);
        });
    }
1264
    /// End-to-end `ToolSet` checks covering:
    /// - case-insensitive lookups;
    /// - definition and function selection;
    /// - resource filtering, where selected resources are removed from the
    ///   input;
    /// - dispatch through `DynTool`;
    /// - duplicate and invalid-name errors.
    #[test]
    fn tool_set_registry_filters_resources_and_reports_errors() {
        futures::executor::block_on(async {
            let mut tool_set = ToolSet::<TestContext>::new();
            tool_set.add(Arc::new(ExampleTool { id: 1 })).unwrap();
            tool_set.add(Arc::new(TaggedTool)).unwrap();

            assert!(tool_set.contains("EXAMPLE_TOOL"));
            assert!(tool_set.contains_lowercase("tagged_tool"));
            assert!(!tool_set.contains("missing_tool"));
            assert_eq!(
                tool_set.names(),
                vec!["example_tool".to_string(), "tagged_tool".to_string()]
            );

            let definition = tool_set.definition("TAGGED_TOOL").unwrap();
            assert_eq!(definition.name, "tagged_tool");
            assert!(tool_set.definition("missing_tool").is_none());

            // Unknown names are silently skipped.
            let selected_names = vec!["TAGGED_TOOL".to_string(), "missing_tool".to_string()];
            let selected_definitions = tool_set.definitions(Some(&selected_names));
            assert_eq!(selected_definitions.len(), 1);
            assert_eq!(selected_definitions[0].name, "tagged_tool");
            assert_eq!(tool_set.definitions(None).len(), 2);

            let selected_functions = tool_set.functions(Some(&selected_names));
            assert_eq!(selected_functions.len(), 1);
            assert_eq!(
                selected_functions[0].supported_resource_tags,
                vec!["text".to_string(), "code".to_string()]
            );
            assert_eq!(tool_set.functions(None).len(), 2);

            // Resources tagged `text`/`code` move out; the rest stay behind.
            let mut resources = vec![
                resource(1, &["image"]),
                resource(2, &["text"]),
                resource(3, &["code", "text"]),
                resource(4, &["audio"]),
            ];
            let selected = tool_set.select_resources("TAGGED_TOOL", &mut resources);
            assert_eq!(
                selected
                    .iter()
                    .map(|resource| resource._id)
                    .collect::<Vec<_>>(),
                vec![2, 3]
            );
            assert_eq!(
                resources
                    .iter()
                    .map(|resource| resource._id)
                    .collect::<Vec<_>>(),
                vec![1, 4]
            );
            assert!(
                tool_set
                    .select_resources("missing_tool", &mut resources)
                    .is_empty()
            );

            let dyn_tool = tool_set.get_lowercase("tagged_tool").unwrap();
            let output = dyn_tool
                .call(
                    TestContext::default(),
                    json!({"value": "ok", "fail": false}),
                    vec![resource(9, &["text"])],
                )
                .await
                .unwrap();
            assert_eq!(output.output["value"], "ok");
            assert_eq!(output.output["resources"], 1);
            assert_eq!(output.is_error, Some(false));
            assert_eq!(output.usage.requests, 1);
            assert!(tool_set.get("missing_tool").is_none());
            assert!(tool_set.get_lowercase("missing_tool").is_none());

            let failed = dyn_tool
                .call(
                    TestContext::default(),
                    json!({"value": "bad", "fail": true}),
                    Vec::new(),
                )
                .await
                .unwrap_err();
            assert!(failed.to_string().contains("call failed"));

            let duplicate = tool_set.add(Arc::new(ExampleTool { id: 2 })).unwrap_err();
            assert!(duplicate.to_string().contains("already exists"));

            let invalid = tool_set.add(Arc::new(InvalidTool)).unwrap_err();
            assert!(invalid.to_string().contains("invalid character"));
        });
    }
1358
    /// Calls every stubbed `TestContext` feature once and checks the fixed
    /// result each stub returns.
    #[test]
    fn test_tool_context_mock_features_cover_default_paths() {
        futures::executor::block_on(async {
            let ctx = TestContext::default();
            // State features.
            assert_eq!(*ctx.engine_id(), Principal::management_canister());
            assert_eq!(ctx.engine_name(), "test-engine");
            assert_eq!(*ctx.caller(), Principal::anonymous());
            assert!(ctx.meta().user.is_none());
            assert!(!ctx.cancellation_token().is_cancelled());
            assert_eq!(ctx.time_elapsed(), Duration::ZERO);

            // Key features.
            assert_eq!(ctx.a256gcm_key(Vec::new()).await.unwrap(), [0; 32]);
            assert_eq!(
                ctx.ed25519_sign_message(Vec::new(), b"message")
                    .await
                    .unwrap(),
                [0; 64]
            );
            ctx.ed25519_verify(Vec::new(), b"message", &[0; 64])
                .await
                .unwrap();
            assert_eq!(ctx.ed25519_public_key(Vec::new()).await.unwrap(), [0; 32]);
            assert_eq!(
                ctx.secp256k1_sign_message_bip340(Vec::new(), b"message")
                    .await
                    .unwrap(),
                [0; 64]
            );
            ctx.secp256k1_verify_bip340(Vec::new(), b"message", &[0; 64])
                .await
                .unwrap();
            assert_eq!(
                ctx.secp256k1_sign_message_ecdsa(Vec::new(), b"message")
                    .await
                    .unwrap(),
                [0; 64]
            );
            assert_eq!(
                ctx.secp256k1_sign_digest_ecdsa(Vec::new(), &[0; 32])
                    .await
                    .unwrap(),
                [0; 64]
            );
            ctx.secp256k1_verify_ecdsa(Vec::new(), &[0; 32], &[0; 64])
                .await
                .unwrap();
            assert_eq!(ctx.secp256k1_public_key(Vec::new()).await.unwrap(), [0; 33]);

            // Store features.
            assert!(ctx.store_get(&Path::from("missing")).await.is_err());
            assert!(
                ctx.store_list(None, &Path::default())
                    .await
                    .unwrap()
                    .is_empty()
            );
            assert!(
                ctx.store_put(&Path::from("file"), PutMode::Overwrite, bytes::Bytes::new())
                    .await
                    .is_err()
            );
            assert!(
                ctx.store_rename_if_not_exists(&Path::from("a"), &Path::from("b"))
                    .await
                    .is_err()
            );
            ctx.store_delete(&Path::from("file")).await.unwrap();

            // Cache features.
            assert!(!ctx.cache_contains("key"));
            assert!(ctx.cache_get::<String>("key").await.is_err());
            assert!(
                ctx.cache_get_with("key", async { Ok(("value".to_string(), None)) })
                    .await
                    .is_err()
            );
            ctx.cache_set("key", ("value".to_string(), None)).await;
            assert!(
                !ctx.cache_set_if_not_exists("key", ("value".to_string(), None))
                    .await
            );
            assert!(!ctx.cache_delete("key").await);
            assert_eq!(ctx.cache_raw_iter().count(), 0);

            // HTTP features.
            assert!(
                ctx.https_call("https://example.test", http::Method::GET, None, None)
                    .await
                    .is_err()
            );
            assert!(
                ctx.https_signed_call(
                    "https://example.test",
                    http::Method::POST,
                    [0; 32],
                    None,
                    None,
                )
                .await
                .is_err()
            );
            let rpc: Result<String, BoxError> = ctx
                .https_signed_rpc("https://example.test", "method", &())
                .await;
            assert!(rpc.is_err());

            // Canister calls.
            let query: Result<String, BoxError> = ctx
                .canister_query(&Principal::anonymous(), "query", ())
                .await;
            assert!(query.is_err());
            let update: Result<String, BoxError> = ctx
                .canister_update(&Principal::anonymous(), "update", ())
                .await;
            assert!(update.is_err());

            // Remote tool calls.
            assert!(
                ctx.remote_tool_call(
                    "https://example.test",
                    ToolInput::new("tool".to_string(), Json::Null),
                )
                .await
                .is_err()
            );
        });
    }
1481
1482 struct GroupedTool {
1483 name: &'static str,
1484 group: &'static str,
1485 }
1486
1487 impl Tool<TestContext> for GroupedTool {
1488 type Args = ();
1489 type Output = String;
1490
1491 fn name(&self) -> String {
1492 self.name.to_string()
1493 }
1494
1495 fn description(&self) -> String {
1496 "Grouped tool fixture".to_string()
1497 }
1498
1499 fn definition(&self) -> FunctionDefinition {
1500 FunctionDefinition {
1501 name: self.name(),
1502 description: self.description(),
1503 parameters: json!({
1504 "type": "object",
1505 "properties": {},
1506 "required": [],
1507 "additionalProperties": false
1508 }),
1509 strict: Some(true),
1510 }
1511 }
1512
1513 fn group(&self) -> Option<ToolGroupInfo> {
1514 Some(ToolGroupInfo {
1515 id: self.group.to_string(),
1516 title: format!("{} title", self.group),
1517 description: format!("{} description", self.group),
1518 instructions: Some(format!("{} instructions", self.group)),
1519 })
1520 }
1521
1522 async fn call(
1523 &self,
1524 _ctx: TestContext,
1525 _args: Self::Args,
1526 _resources: Vec<Resource>,
1527 ) -> Result<ToolOutput<Self::Output>, BoxError> {
1528 Ok(ToolOutput::new(String::new()))
1529 }
1530 }
1531
1532 #[test]
1533 fn tool_set_groups_aggregate_members_by_id() {
1534 let mut tool_set = ToolSet::<TestContext>::new();
1535 tool_set
1536 .add(Arc::new(GroupedTool {
1537 name: "fs_write",
1538 group: "fs",
1539 }))
1540 .unwrap();
1541 tool_set
1542 .add(Arc::new(GroupedTool {
1543 name: "fs_read",
1544 group: "fs",
1545 }))
1546 .unwrap();
1547 tool_set
1548 .add(Arc::new(GroupedTool {
1549 name: "mem_get",
1550 group: "memory",
1551 }))
1552 .unwrap();
1553 tool_set.add(Arc::new(ExampleTool { id: 1 })).unwrap();
1555
1556 let groups = tool_set.groups();
1557 assert_eq!(groups.len(), 2);
1558
1559 let fs = groups.iter().find(|group| group.id == "fs").unwrap();
1560 assert_eq!(
1562 fs.members,
1563 vec!["fs_read".to_string(), "fs_write".to_string()]
1564 );
1565 assert_eq!(fs.title, "fs title");
1566 assert_eq!(fs.instructions.as_deref(), Some("fs instructions"));
1567
1568 let memory = groups.iter().find(|group| group.id == "memory").unwrap();
1569 assert_eq!(memory.members, vec!["mem_get".to_string()]);
1570 }
1571}