use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Describes a tool: its name, a description, the schema of its input,
/// where it comes from, and free-form string annotations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Name of the tool.
pub name: String,
/// Description of the tool.
pub description: String,
/// Schema describing the tool's input.
pub input_schema: InputSchema,
/// Origin of the tool. When the field is missing from the serialized form,
/// deserialization falls back to `ToolSource::default()`.
#[serde(default)]
pub source: ToolSource,
/// String key/value annotations. Defaults to an empty map when missing on
/// deserialization, and is left out of the serialized output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub annotations: HashMap<String, String>,
}
impl ToolDefinition {
    /// Creates a definition from a name, a description and an input schema.
    ///
    /// The result starts with a [`ToolSource::Inline`] source and an empty
    /// annotation map; use the `with_*` methods to change either.
    #[must_use]
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        input_schema: InputSchema,
    ) -> Self {
        let name = name.into();
        let description = description.into();
        Self {
            name,
            description,
            input_schema,
            source: ToolSource::Inline,
            annotations: HashMap::new(),
        }
    }

    /// Returns the definition with its source replaced by `source`.
    #[must_use]
    pub fn with_source(self, source: ToolSource) -> Self {
        Self { source, ..self }
    }

    /// Returns the definition with the annotation `key` set to `value`.
    ///
    /// An existing annotation stored under the same key is overwritten.
    #[must_use]
    pub fn with_annotation(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.annotations.insert(key, value);
        self
    }

    /// Reports whether the source is the [`ToolSource::Mcp`] variant.
    #[must_use]
    pub fn is_mcp(&self) -> bool {
        matches!(&self.source, ToolSource::Mcp { .. })
    }

    /// Reports whether the source is the [`ToolSource::OpenApi`] variant.
    #[must_use]
    pub fn is_openapi(&self) -> bool {
        matches!(&self.source, ToolSource::OpenApi { .. })
    }

    /// Reports whether the source is the [`ToolSource::GraphQl`] variant.
    #[must_use]
    pub fn is_graphql(&self) -> bool {
        matches!(&self.source, ToolSource::GraphQl { .. })
    }
}
/// Identifies where a tool definition originates from.
///
/// Serialized by serde as an internally tagged enum: the variant is recorded
/// in a `type` field, with variant names converted to snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolSource {
/// Tool associated with an MCP server, described by a name and a URI.
Mcp {
server_name: String,
server_uri: String,
},
/// Tool associated with an OpenAPI operation.
///
/// NOTE(review): `spec_path` presumably locates the spec document, and
/// `method`/`path` presumably are the HTTP method and route of the
/// operation named by `operation_id` — confirm against the producer.
OpenApi {
spec_path: String,
operation_id: String,
method: String,
path: String,
},
/// Tool associated with a GraphQL operation, described by an endpoint,
/// an operation name and the kind of operation.
GraphQl {
endpoint: String,
operation_name: String,
operation_type: GraphQlOperationType,
},
/// Carries no extra data. This is the `Default` variant.
#[default]
Inline,
}
/// Kind of GraphQL operation.
///
/// Variant names are serialized in lowercase (`query`, `mutation`,
/// `subscription`). `Query` is the `Default` variant.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum GraphQlOperationType {
#[default]
Query,
Mutation,
Subscription,
}
/// Wrapper around a JSON value holding a tool's input schema.
///
/// NOTE(review): the derived `Default` uses `serde_json::Value::default()`,
/// which is expected to be `Value::Null` rather than the object built by
/// `InputSchema::empty()` — confirm this difference is intended.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct InputSchema {
/// The schema document. `flatten` makes its keys appear directly in the
/// serialized `InputSchema` instead of under a nested `schema` key.
// NOTE(review): flattening presumably requires the value to be a JSON
// object (or unit) at serialization time — verify with the serde version in use.
#[serde(flatten)]
pub schema: serde_json::Value,
}
impl InputSchema {
    /// Wraps an already-built JSON schema value without inspecting it.
    #[must_use]
    pub fn from_json_schema(schema: serde_json::Value) -> Self {
        Self { schema }
    }

    /// Builds an object schema that declares no properties and requires none.
    #[must_use]
    pub fn empty() -> Self {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {},
            "required": []
        });
        Self::from_json_schema(schema)
    }

    /// Collects the string entries of the schema's `required` array.
    ///
    /// Returns an empty vector when `required` is absent or is not an array;
    /// array entries that are not strings are skipped.
    #[must_use]
    pub fn required_properties(&self) -> Vec<String> {
        match self.schema.get("required") {
            Some(serde_json::Value::Array(entries)) => entries
                .iter()
                .filter_map(serde_json::Value::as_str)
                .map(str::to_owned)
                .collect(),
            _ => Vec::new(),
        }
    }

    /// Borrows the schema's `properties` map.
    ///
    /// Returns `None` when `properties` is absent or is not a JSON object.
    #[must_use]
    pub fn properties(&self) -> Option<&serde_json::Map<String, serde_json::Value>> {
        self.schema.get("properties")?.as_object()
    }
}
/// A request to invoke a tool with a set of arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call; `ToolCall::new` fills it with a generated id.
pub call_id: String,
/// Name of the tool to invoke.
pub tool_name: String,
/// Arguments for the tool, as an arbitrary JSON value.
pub arguments: serde_json::Value,
}
impl ToolCall {
    /// Creates a call to `tool_name` with the given `arguments`.
    ///
    /// The call id is freshly generated by `generate_call_id`.
    #[must_use]
    pub fn new(tool_name: impl Into<String>, arguments: serde_json::Value) -> Self {
        let call_id = generate_call_id();
        let tool_name = tool_name.into();
        Self {
            call_id,
            tool_name,
            arguments,
        }
    }
}
/// Outcome of a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
/// Identifier of the call this result belongs to.
pub call_id: String,
/// Payload of the result, either text or arbitrary JSON.
pub content: ToolResultContent,
/// `true` when the result represents an error (see `ToolResult::error`).
pub is_error: bool,
/// String key/value metadata. Defaults to an empty map when missing on
/// deserialization, and is left out of the serialized output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub metadata: HashMap<String, String>,
}
impl ToolResult {
    /// Creates a successful (`is_error == false`) text result for `call_id`
    /// with empty metadata.
    #[must_use]
    pub fn text(call_id: impl Into<String>, content: impl Into<String>) -> Self {
        let call_id = call_id.into();
        let content = ToolResultContent::Text(content.into());
        Self {
            call_id,
            content,
            is_error: false,
            metadata: HashMap::new(),
        }
    }

    /// Creates an error (`is_error == true`) result for `call_id` whose
    /// content is `message` as text, with empty metadata.
    #[must_use]
    pub fn error(call_id: impl Into<String>, message: impl Into<String>) -> Self {
        // Same shape as a text result; only the error flag differs.
        Self {
            is_error: true,
            ..Self::text(call_id, message)
        }
    }

    /// Borrows the content as a string slice.
    ///
    /// Yields the text of a `Text` payload, or the string of a `Json`
    /// payload that holds a JSON string; any other JSON value gives `None`.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        let text = match &self.content {
            ToolResultContent::Json(value) => value.as_str()?,
            ToolResultContent::Text(text) => text.as_str(),
        };
        Some(text)
    }
}
/// Payload of a tool result.
///
/// `untagged` means no variant marker is written: the inner value is
/// serialized as-is. On deserialization serde tries the variants in
/// declaration order, so the order below is significant — a JSON string
/// matches `Text` before `Json` is considered.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolResultContent {
/// Plain text payload.
Text(String),
/// Arbitrary JSON payload.
Json(serde_json::Value),
}
/// Produces a call id of the form `call-<millis>-<sequence>`.
///
/// `<millis>` is the number of milliseconds since the Unix epoch (or `0` if
/// the system clock reports a time before the epoch) and `<sequence>` is a
/// process-wide counter starting at zero; both are written in lowercase hex.
fn generate_call_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static NEXT_SEQUENCE: AtomicU64 = AtomicU64::new(0);

    // `fetch_add` hands back the value before the increment, so every caller
    // observes a distinct sequence number.
    let sequence = NEXT_SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let millis = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis(),
        Err(_) => 0,
    };
    format!("call-{millis:x}-{sequence:x}")
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A freshly constructed definition keeps its name and is `Inline`.
    #[test]
    fn test_tool_definition_creation() {
        let definition = ToolDefinition::new("test_tool", "A test tool", InputSchema::empty());

        assert_eq!(definition.name, "test_tool");
        let is_inline = matches!(definition.source, ToolSource::Inline);
        assert!(is_inline);
    }

    /// A new call keeps the tool name and receives a generated `call-` id.
    #[test]
    fn test_tool_call_creation() {
        let arguments = json!({"input": "hello"});
        let invocation = ToolCall::new("test_tool", arguments);

        assert_eq!(invocation.tool_name, "test_tool");
        assert!(invocation.call_id.starts_with("call-"));
    }

    /// A text result is not an error and exposes its text.
    #[test]
    fn test_tool_result() {
        let outcome = ToolResult::text("call-123", "Success");

        assert!(!outcome.is_error);
        assert_eq!(outcome.as_text(), Some("Success"));
    }
}