#[cfg(feature = "tool-runner")]
use std::collections::BTreeMap;
#[cfg(feature = "tool-runner")]
use std::future::Future;
#[cfg(feature = "tool-runner")]
use std::pin::Pin;
#[cfg(feature = "tool-runner")]
use std::sync::Arc;
use schemars::{JsonSchema, schema_for};
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::error::{Error, Result};
#[cfg(feature = "tool-runner")]
use crate::json_payload::JsonPayload;
use crate::resources::{ChatCompletion, Response};
/// Generates the JSON Schema describing `T` and returns it as a [`Value`].
///
/// The schema is built with `schemars::schema_for!`. Should converting it into a
/// `serde_json::Value` fail, an empty JSON object (`{}`) is returned rather than an error.
pub fn json_schema_for<T>() -> Value
where
    T: JsonSchema,
{
    let schema = schema_for!(T);
    match serde_json::to_value(schema) {
        Ok(value) => value,
        // Fall back to `{}` so callers always receive an object-shaped schema.
        Err(_) => Value::Object(Default::default()),
    }
}
/// Parses a structured-output text payload into `T`.
///
/// The payload is trimmed first. If it is wrapped in a Markdown code fence
/// (three backticks), the fence is removed before parsing:
///
/// * an info-string line on the opening fence (`json`, `JSON`, `jsonc`, or none)
///   is dropped;
/// * a single-line fence such as ```` ```json{"a":1}``` ```` has its leading
///   `json` tag stripped;
/// * a missing closing fence is tolerated, so output that was cut off after the
///   opening fence is still parsed.
///
/// # Errors
///
/// Returns [`Error::Serialization`] when the normalized text cannot be
/// deserialized into `T`.
pub fn parse_json_payload<T>(payload: &str) -> Result<T>
where
    T: DeserializeOwned,
{
    // Returns `true` when the (already trimmed) text after the opening fence on
    // its first line looks like a language tag rather than the start of JSON.
    fn is_fence_info(info: &str) -> bool {
        if info.is_empty() {
            return true;
        }
        // Bare JSON keywords are complete documents on their own; never treat
        // them as a language tag.
        if matches!(info, "true" | "false" | "null") {
            return false;
        }
        info.starts_with(|c: char| c.is_ascii_alphabetic())
            && info
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '.'))
    }

    // Removes a surrounding Markdown code fence from already-trimmed `text`;
    // text that does not start with a fence is returned unchanged.
    fn strip_code_fence(text: &str) -> &str {
        let rest = match text.strip_prefix("```") {
            Some(rest) => rest,
            None => return text,
        };
        let body = match rest.split_once('\n') {
            // "```json\n…", "```JSON\r\n…", "```\n…": drop the whole info line.
            Some((info, body)) if is_fence_info(info.trim()) => body,
            // No separate info line (e.g. "```json{…}```"): keep the historical
            // behaviour of stripping a leading lowercase `json` tag only.
            _ => rest.strip_prefix("json").unwrap_or(rest),
        };
        let body = body.trim();
        // A missing closing fence is tolerated instead of discarding the
        // already-stripped body.
        body.strip_suffix("```").map_or(body, str::trim)
    }

    let normalized = strip_code_fence(payload.trim());
    serde_json::from_str(normalized).map_err(|error| {
        Error::Serialization(crate::SerializationError::new(format!(
            "结构化 JSON 解析失败: {error}"
        )))
    })
}
/// A [`ChatCompletion`] bundled with a value of type `T` parsed from it.
#[derive(Clone, Debug)]
pub struct ParsedChatCompletion<T> {
    /// The underlying chat completion.
    pub response: ChatCompletion,
    /// The parsed value of type `T`.
    pub parsed: T,
}
/// A [`Response`] bundled with a value of type `T` parsed from it.
#[derive(Clone, Debug)]
pub struct ParsedResponse<T> {
    /// The underlying response.
    pub response: Response,
    /// The parsed value of type `T`.
    pub parsed: T,
}
/// The boxed, pinned, `Send` future returned by a tool handler.
///
/// It resolves to the tool's JSON result or an error. The `'static` bound is the
/// default object lifetime for a boxed trait object in a type alias; it is
/// spelled out here only for clarity.
#[cfg(feature = "tool-runner")]
#[cfg_attr(docsrs, doc(cfg(feature = "tool-runner")))]
pub type ToolFuture = Pin<Box<dyn Future<Output = Result<Value>> + Send + 'static>>;
/// An asynchronous tool implementation invoked with JSON arguments.
///
/// The trait requires `Send + Sync` so a handler can be stored behind a shared,
/// thread-safe pointer. Closures of the form `Fn(Value) -> Fut`, where `Fut` is a
/// `Send + 'static` future, implement this trait automatically.
#[cfg(feature = "tool-runner")]
#[cfg_attr(docsrs, doc(cfg(feature = "tool-runner")))]
pub trait ToolHandler: Send + Sync {
    /// Invokes the tool with `arguments`.
    ///
    /// Returns a boxed future that resolves to the tool's JSON output or an error.
    fn call(&self, arguments: Value) -> ToolFuture;
}
/// Blanket implementation: any thread-safe `Fn(Value) -> Fut` can serve as a tool handler.
///
/// The future returned by the closure is boxed and pinned into a [`ToolFuture`].
#[cfg(feature = "tool-runner")]
impl<F, Fut> ToolHandler for F
where
    F: Fn(Value) -> Fut + Send + Sync,
    Fut: Future<Output = Result<Value>> + Send + 'static,
{
    fn call(&self, arguments: Value) -> ToolFuture {
        let pending = self(arguments);
        Box::pin(pending)
    }
}
/// A named tool: its public metadata together with the handler that executes it.
#[cfg(feature = "tool-runner")]
#[cfg_attr(docsrs, doc(cfg(feature = "tool-runner")))]
#[derive(Clone)]
pub struct ToolDefinition {
    /// The tool's name.
    pub name: String,
    /// An optional description of the tool.
    pub description: Option<String>,
    /// The payload describing the tool's parameters.
    pub parameters: JsonPayload,
    /// A shared, type-erased handler. Cloning a definition clones this `Arc`,
    /// not the handler itself.
    handler: Arc<dyn ToolHandler>,
}
/// Debug-formats a [`ToolDefinition`]'s metadata.
///
/// The `handler` field is an `Arc<dyn ToolHandler>`, and `ToolHandler` has no
/// `Debug` supertrait, so the handler cannot be printed. `finish_non_exhaustive`
/// renders a trailing `..` so the omission is visible, instead of implying that
/// the listed fields make up the whole struct.
#[cfg(feature = "tool-runner")]
impl std::fmt::Debug for ToolDefinition {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ToolDefinition")
            .field("name", &self.name)
            .field("description", &self.description)
            .field("parameters", &self.parameters)
            .finish_non_exhaustive()
    }
}
#[cfg(feature = "tool-runner")]
impl ToolDefinition {
    /// Builds a tool definition from an explicit parameter payload and a handler.
    ///
    /// The handler is moved into an `Arc` so that clones of the definition share it.
    pub fn new<T, U, H>(
        name: T,
        description: Option<U>,
        parameters: impl Into<JsonPayload>,
        handler: H,
    ) -> Self
    where
        T: Into<String>,
        U: Into<String>,
        H: ToolHandler + 'static,
    {
        let name: String = name.into();
        let description: Option<String> = description.map(Into::into);
        let parameters: JsonPayload = parameters.into();
        let handler: Arc<dyn ToolHandler> = Arc::new(handler);
        Self {
            name,
            description,
            parameters,
            handler,
        }
    }

    /// Builds a tool definition whose parameters are the JSON Schema generated for `TArgs`.
    ///
    /// This is shorthand for [`ToolDefinition::new`] with `json_schema_for::<TArgs>()`
    /// as the parameter payload.
    pub fn from_schema<TArgs, T, U, H>(name: T, description: Option<U>, handler: H) -> Self
    where
        TArgs: JsonSchema,
        T: Into<String>,
        U: Into<String>,
        H: ToolHandler + 'static,
    {
        Self::new(name, description, json_schema_for::<TArgs>(), handler)
    }

    /// Runs the tool's handler with `arguments` and awaits its JSON result.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by the handler's future.
    pub async fn invoke(&self, arguments: Value) -> Result<Value> {
        let pending = self.handler.call(arguments);
        pending.await
    }
}
/// A collection of [`ToolDefinition`]s keyed by name.
///
/// Tools are stored in a `BTreeMap`, so iteration visits them in ascending name order.
#[cfg(feature = "tool-runner")]
#[cfg_attr(docsrs, doc(cfg(feature = "tool-runner")))]
#[derive(Clone, Debug, Default)]
pub struct ToolRegistry {
    /// Registered tools, keyed by tool name.
    tools: BTreeMap<String, ToolDefinition>,
}
#[cfg(feature = "tool-runner")]
impl ToolRegistry {
    /// Creates a registry with no tools.
    pub fn new() -> Self {
        Self {
            tools: BTreeMap::new(),
        }
    }

    /// Adds `tool`, keyed by its `name`.
    ///
    /// A previously registered tool with the same name is replaced and dropped.
    pub fn register(&mut self, tool: ToolDefinition) {
        let key = tool.name.clone();
        self.tools.insert(key, tool);
    }

    /// Looks up a tool by exact name.
    pub fn get(&self, name: &str) -> Option<&ToolDefinition> {
        self.tools.get(name)
    }

    /// Iterates over every registered tool in ascending name order.
    pub fn all(&self) -> impl Iterator<Item = &ToolDefinition> {
        self.tools.iter().map(|(_, tool)| tool)
    }

    /// Returns `true` when no tools have been registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }
}