ds_api/tool_trait.rs
1use crate::raw::request::tool::Tool as RawTool;
2use async_trait::async_trait;
3use serde_json::Value;
4use serde_json::json;
5use std::collections::HashMap;
6
7/// The core trait that all agent tools must implement.
8///
9/// You should not implement this trait manually. Instead, annotate your `impl` block
10/// with the [`#[tool]`][ds_api_macros::tool] macro and write plain `async fn` methods —
11/// the macro generates the `raw_tools` and `call` implementations for you.
12///
13/// # What the macro generates
14///
15/// For each `async fn` in the annotated `impl`:
16/// - A [`RawTool`] entry (name, description from doc comment, JSON Schema from parameter types)
17/// is added to the `raw_tools()` vec.
18/// - A `match` arm in `call()` that deserialises each argument from the incoming `args` JSON,
19/// invokes the method, and serialises the return value via `serde_json::to_value`.
20///
21/// Any return type that implements `serde::Serialize` is accepted — `serde_json::Value`,
22/// plain structs with `#[derive(Serialize)]`, primitives, `Option<T>`, `Vec<T>`, etc.
23///
24/// # Example
25///
26/// ```no_run
27/// use ds_api::{DeepseekAgent, tool};
28/// use serde_json::{Value, json};
29///
30/// struct Calc;
31///
32/// #[tool]
33/// impl ds_api::Tool for Calc {
34/// /// Add two integers together.
35/// /// a: first operand
36/// /// b: second operand
37/// async fn add(&self, a: i64, b: i64) -> i64 {
38/// a + b
39/// }
40/// }
41///
42/// # #[tokio::main] async fn main() {
43/// let agent = DeepseekAgent::new("sk-...").add_tool(Calc);
44/// # }
45/// ```
46#[async_trait]
47pub trait Tool: Send + Sync {
48 /// Return the list of raw tool definitions to send to the API.
49 fn raw_tools(&self) -> Vec<RawTool>;
50
51 /// Invoke the named tool with the given arguments and return the result as a JSON value.
52 ///
53 /// When using the `#[tool]` macro you do not implement this method yourself —
54 /// the macro generates it. The generated implementation accepts any return type
55 /// that implements `serde::Serialize` (including `serde_json::Value`, plain
56 /// structs with `#[derive(Serialize)]`, primitives, etc.) and converts the
57 /// value to `serde_json::Value` automatically.
58 async fn call(&self, name: &str, args: Value) -> Value;
59}
60
61/// 将多个 Tool 合并为一个,方便批量注册进 agent。
62pub struct ToolBundle {
63 tools: Vec<Box<dyn Tool>>,
64 index: std::collections::HashMap<String, usize>,
65}
66
67impl ToolBundle {
68 pub fn new() -> Self {
69 Self {
70 tools: vec![],
71 index: HashMap::new(),
72 }
73 }
74
75 pub fn add<T: Tool + 'static>(mut self, tool: T) -> Self {
76 let idx = self.tools.len();
77 for raw in tool.raw_tools() {
78 self.index.insert(raw.function.name.clone(), idx);
79 }
80 self.tools.push(Box::new(tool));
81 self
82 }
83}
84
85#[async_trait]
86impl Tool for ToolBundle {
87 fn raw_tools(&self) -> Vec<RawTool> {
88 self.tools.iter().flat_map(|t| t.raw_tools()).collect()
89 }
90
91 async fn call(&self, name: &str, args: Value) -> Value {
92 match self.index.get(name) {
93 Some(&idx) => self.tools[idx].call(name, args).await,
94 None => json!({ "error": format!("未知工具: {name}") }),
95 }
96 }
97}