Skip to main content

openai_core/helpers/
mod.rs

1//! Structured output 与工具调用辅助能力。
2
3#[cfg(feature = "tool-runner")]
4use std::collections::BTreeMap;
5#[cfg(feature = "tool-runner")]
6use std::future::Future;
7#[cfg(feature = "tool-runner")]
8use std::pin::Pin;
9#[cfg(feature = "tool-runner")]
10use std::sync::Arc;
11
12use schemars::{JsonSchema, schema_for};
13use serde::de::DeserializeOwned;
14use serde_json::Value;
15
16use crate::error::{Error, Result};
17#[cfg(feature = "tool-runner")]
18use crate::json_payload::JsonPayload;
19use crate::resources::{ChatCompletion, Response};
20
21/// 返回指定类型对应的 JSON Schema。
22pub fn json_schema_for<T>() -> Value
23where
24    T: JsonSchema,
25{
26    serde_json::to_value(schema_for!(T)).unwrap_or_else(|_| Value::Object(Default::default()))
27}
28
29/// 尝试从文本中提取并解析 JSON。
30///
31/// 该函数会自动去掉常见的 Markdown 代码块包裹。
32///
33/// # Errors
34///
35/// 当 JSON 解析失败时返回错误。
36pub fn parse_json_payload<T>(payload: &str) -> Result<T>
37where
38    T: DeserializeOwned,
39{
40    let trimmed = payload.trim();
41    let normalized = trimmed
42        .strip_prefix("```json")
43        .or_else(|| trimmed.strip_prefix("```"))
44        .map(|value| value.trim())
45        .and_then(|value| value.strip_suffix("```"))
46        .map_or(trimmed, str::trim);
47
48    serde_json::from_str(normalized).map_err(|error| {
49        Error::Serialization(crate::SerializationError::new(format!(
50            "结构化 JSON 解析失败: {error}"
51        )))
52    })
53}
54
55/// 表示已经解析出结构化对象的聊天补全结果。
56#[derive(Debug, Clone)]
57pub struct ParsedChatCompletion<T> {
58    /// 原始聊天补全结果。
59    pub response: ChatCompletion,
60    /// 反序列化后的结构化对象。
61    pub parsed: T,
62}
63
64/// 表示已经解析出结构化对象的 Responses 结果。
65#[derive(Debug, Clone)]
66pub struct ParsedResponse<T> {
67    /// 原始 Responses 结果。
68    pub response: Response,
69    /// 反序列化后的结构化对象。
70    pub parsed: T,
71}
72
73/// 工具处理函数的异步返回值类型。
74#[cfg(feature = "tool-runner")]
75#[cfg_attr(docsrs, doc(cfg(feature = "tool-runner")))]
76pub type ToolFuture = Pin<Box<dyn Future<Output = Result<Value>> + Send>>;
77
78/// 表示工具处理器。
79#[cfg(feature = "tool-runner")]
80#[cfg_attr(docsrs, doc(cfg(feature = "tool-runner")))]
81pub trait ToolHandler: Send + Sync {
82    /// 执行一个工具调用。
83    fn call(&self, arguments: Value) -> ToolFuture;
84}
85
86#[cfg(feature = "tool-runner")]
87impl<F, Fut> ToolHandler for F
88where
89    F: Fn(Value) -> Fut + Send + Sync,
90    Fut: Future<Output = Result<Value>> + Send + 'static,
91{
92    fn call(&self, arguments: Value) -> ToolFuture {
93        Box::pin((self)(arguments))
94    }
95}
96
97/// 表示单个工具定义。
98#[cfg(feature = "tool-runner")]
99#[cfg_attr(docsrs, doc(cfg(feature = "tool-runner")))]
100#[derive(Clone)]
101pub struct ToolDefinition {
102    /// 工具名称。
103    pub name: String,
104    /// 工具描述。
105    pub description: Option<String>,
106    /// 工具参数 JSON Schema。
107    pub parameters: JsonPayload,
108    handler: Arc<dyn ToolHandler>,
109}
110
111#[cfg(feature = "tool-runner")]
112impl std::fmt::Debug for ToolDefinition {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("ToolDefinition")
115            .field("name", &self.name)
116            .field("description", &self.description)
117            .field("parameters", &self.parameters)
118            .finish()
119    }
120}
121
122#[cfg(feature = "tool-runner")]
123impl ToolDefinition {
124    /// 使用显式 JSON Schema 创建工具定义。
125    pub fn new<T, U, H>(
126        name: T,
127        description: Option<U>,
128        parameters: impl Into<JsonPayload>,
129        handler: H,
130    ) -> Self
131    where
132        T: Into<String>,
133        U: Into<String>,
134        H: ToolHandler + 'static,
135    {
136        Self {
137            name: name.into(),
138            description: description.map(Into::into),
139            parameters: parameters.into(),
140            handler: Arc::new(handler),
141        }
142    }
143
144    /// 使用 `schemars` 自动推导参数 Schema。
145    pub fn from_schema<TArgs, T, U, H>(name: T, description: Option<U>, handler: H) -> Self
146    where
147        TArgs: JsonSchema,
148        T: Into<String>,
149        U: Into<String>,
150        H: ToolHandler + 'static,
151    {
152        Self {
153            name: name.into(),
154            description: description.map(Into::into),
155            parameters: json_schema_for::<TArgs>().into(),
156            handler: Arc::new(handler),
157        }
158    }
159
160    /// 调用工具处理器。
161    pub async fn invoke(&self, arguments: Value) -> Result<Value> {
162        self.handler.call(arguments).await
163    }
164}
165
166/// 表示工具注册表。
167#[cfg(feature = "tool-runner")]
168#[cfg_attr(docsrs, doc(cfg(feature = "tool-runner")))]
169#[derive(Debug, Clone, Default)]
170pub struct ToolRegistry {
171    tools: BTreeMap<String, ToolDefinition>,
172}
173
174#[cfg(feature = "tool-runner")]
175impl ToolRegistry {
176    /// 创建空的工具注册表。
177    pub fn new() -> Self {
178        Self::default()
179    }
180
181    /// 注册一个工具。
182    pub fn register(&mut self, tool: ToolDefinition) {
183        self.tools.insert(tool.name.clone(), tool);
184    }
185
186    /// 查询指定名称的工具。
187    pub fn get(&self, name: &str) -> Option<&ToolDefinition> {
188        self.tools.get(name)
189    }
190
191    /// 返回所有工具定义。
192    pub fn all(&self) -> impl Iterator<Item = &ToolDefinition> {
193        self.tools.values()
194    }
195
196    /// 判断注册表是否为空。
197    pub fn is_empty(&self) -> bool {
198        self.tools.is_empty()
199    }
200}