Skip to main content

ds_api/
tool_trait.rs

1use crate::raw::request::tool::Tool as RawTool;
2use async_trait::async_trait;
3use serde_json::Value;
4use serde_json::json;
5use std::collections::HashMap;
6
7/// The core trait that all agent tools must implement.
8///
9/// You should not implement this trait manually. Instead, annotate your `impl` block
10/// with the [`#[tool]`][ds_api_macros::tool] macro and write plain `async fn` methods —
11/// the macro generates the `raw_tools` and `call` implementations for you.
12///
13/// # What the macro generates
14///
15/// For each `async fn` in the annotated `impl`:
16/// - A [`RawTool`] entry (name, description from doc comment, JSON Schema from parameter types)
17///   is added to the `raw_tools()` vec.
18/// - A `match` arm in `call()` that deserialises each argument from the incoming `args` JSON,
19///   invokes the method, and serialises the return value via `serde_json::to_value`.
20///
21/// Any return type that implements `serde::Serialize` is accepted — `serde_json::Value`,
22/// plain structs with `#[derive(Serialize)]`, primitives, `Option<T>`, `Vec<T>`, etc.
23///
24/// # Example
25///
26/// ```no_run
27/// use ds_api::{DeepseekAgent, tool};
28/// use serde_json::{Value, json};
29///
30/// struct Calc;
31///
32/// #[tool]
33/// impl ds_api::Tool for Calc {
34///     /// Add two integers together.
35///     /// a: first operand
36///     /// b: second operand
37///     async fn add(&self, a: i64, b: i64) -> i64 {
38///         a + b
39///     }
40/// }
41///
42/// # #[tokio::main] async fn main() {
43/// let agent = DeepseekAgent::new("sk-...").add_tool(Calc);
44/// # }
45/// ```
46#[async_trait]
47pub trait Tool: Send + Sync {
48    /// Return the list of raw tool definitions to send to the API.
49    fn raw_tools(&self) -> Vec<RawTool>;
50
51    /// Invoke the named tool with the given arguments and return the result as a JSON value.
52    ///
53    /// When using the `#[tool]` macro you do not implement this method yourself —
54    /// the macro generates it. The generated implementation accepts any return type
55    /// that implements `serde::Serialize` (including `serde_json::Value`, plain
56    /// structs with `#[derive(Serialize)]`, primitives, etc.) and converts the
57    /// value to `serde_json::Value` automatically.
58    async fn call(&self, name: &str, args: Value) -> Value;
59}
60
61/// 将多个 Tool 合并为一个,方便批量注册进 agent。
62pub struct ToolBundle {
63    tools: Vec<Box<dyn Tool>>,
64    index: std::collections::HashMap<String, usize>,
65}
66
67impl ToolBundle {
68    pub fn new() -> Self {
69        Self {
70            tools: vec![],
71            index: HashMap::new(),
72        }
73    }
74
75    pub fn add<T: Tool + 'static>(mut self, tool: T) -> Self {
76        let idx = self.tools.len();
77        for raw in tool.raw_tools() {
78            self.index.insert(raw.function.name.clone(), idx);
79        }
80        self.tools.push(Box::new(tool));
81        self
82    }
83}
84
85#[async_trait]
86impl Tool for ToolBundle {
87    fn raw_tools(&self) -> Vec<RawTool> {
88        self.tools.iter().flat_map(|t| t.raw_tools()).collect()
89    }
90
91    async fn call(&self, name: &str, args: Value) -> Value {
92        match self.index.get(name) {
93            Some(&idx) => self.tools[idx].call(name, args).await,
94            None => json!({ "error": format!("未知工具: {name}") }),
95        }
96    }
97}