use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use futures::future::BoxFuture;
use serde_json::Value;
use crate::core::{DynTool, ToolContext};
use crate::error::Result;
use crate::genai_types::{FunctionDeclaration, Schema};
/// Type-erased, shareable async callback that backs a [`FunctionTool`].
///
/// It receives the tool-call arguments as JSON plus a mutable borrow of the
/// [`ToolContext`]. The boxed future it returns is bounded by that same borrow's
/// lifetime `'a`.
pub(crate) type FunctionToolFn = Arc<
    dyn for<'a> Fn(Value, &'a mut ToolContext) -> BoxFuture<'a, Result<Value>>
        + Send
        + Sync
        + 'static,
>;
/// A tool whose behaviour comes from an async callback rather than from a hand-written
/// [`DynTool`] implementation.
pub struct FunctionTool {
    /// Tool name. It is reported through the declaration and used in the default
    /// confirmation prompt.
    name: String,
    /// Description passed to the function declaration.
    description: String,
    /// Argument schema. When `None`, `Schema::object()` is declared instead.
    parameters: Option<Schema>,
    /// Callback invoked when the tool is run.
    f: FunctionToolFn,
    /// Value reported by `is_long_running`.
    long_running: bool,
    /// When `true`, every call needs confirmation; the arguments are not inspected.
    require_confirmation: bool,
    /// Custom confirmation prompt. A generic prompt is used when `None`.
    confirmation_hint: Option<String>,
}
impl std::fmt::Debug for FunctionTool {
    /// Prints the tool's identity and all of its behavioural flags.
    ///
    /// The callback (an opaque closure) and the parameter schema are not printed,
    /// so the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FunctionTool")
            .field("name", &self.name)
            .field("description", &self.description)
            .field("long_running", &self.long_running)
            // Both confirmation settings change how the tool is executed, so they belong
            // next to `long_running` when debugging.
            .field("require_confirmation", &self.require_confirmation)
            .field("confirmation_hint", &self.confirmation_hint)
            .finish_non_exhaustive()
    }
}
impl FunctionTool {
    /// Creates a tool from an already type-erased callback.
    ///
    /// The new tool is not long-running, does not require confirmation and has no
    /// custom confirmation hint. Use the builder methods to change these.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: Option<Schema>,
        f: FunctionToolFn,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            parameters,
            long_running: false,
            require_confirmation: false,
            confirmation_hint: None,
            f,
        }
    }

    /// Sets the value reported by [`DynTool::is_long_running`].
    #[must_use]
    pub fn with_long_running(mut self, yes: bool) -> Self {
        self.long_running = yes;
        self
    }

    /// Sets whether every invocation of this tool must be confirmed.
    ///
    /// The flag applies no matter what the call arguments are.
    #[must_use]
    pub fn require_confirmation(mut self, yes: bool) -> Self {
        self.require_confirmation = yes;
        self
    }

    /// Overrides the prompt returned by [`DynTool::confirmation_hint`].
    ///
    /// This does not by itself turn confirmation on; see
    /// [`FunctionTool::require_confirmation`].
    #[must_use]
    pub fn with_confirmation_hint(mut self, hint: impl Into<String>) -> Self {
        self.confirmation_hint = Some(hint.into());
        self
    }

    /// Creates a tool from an async closure and boxes its future into a
    /// [`FunctionToolFn`].
    ///
    /// `Fut` must be `'static` and is one type for every lifetime `'a`. The returned
    /// future therefore cannot keep the `&mut ToolContext` borrow. Read what you need
    /// from the context synchronously, before building the future. If the future
    /// itself must borrow the context, use [`FunctionTool::new`] with a hand-written
    /// [`FunctionToolFn`].
    pub fn from_async<F, Fut>(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: Option<Schema>,
        f: F,
    ) -> Self
    where
        F: for<'a> Fn(Value, &'a mut ToolContext) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<Value>> + Send + 'static,
    {
        // Capture `f` directly. The wrapper is `Fn`, so it only needs `&F` to call the
        // user closure; an extra `Arc` plus a per-call refcount bump bought nothing.
        let boxed: FunctionToolFn = Arc::new(move |args, ctx| {
            // The future is `'static`, so the `'static` trait object coerces to
            // `BoxFuture<'a, _>` for any borrow lifetime `'a`.
            Box::pin(f(args, ctx)) as Pin<Box<dyn std::future::Future<Output = _> + Send>>
        });
        Self::new(name, description, parameters, boxed)
    }
}
#[async_trait]
impl DynTool for FunctionTool {
    /// Returns the configured tool name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the configured description.
    fn description(&self) -> &str {
        self.description.as_str()
    }

    /// Returns the flag set by `with_long_running`.
    fn is_long_running(&self) -> bool {
        self.long_running
    }

    /// Confirmation is all-or-nothing for a `FunctionTool`: the arguments are ignored.
    fn requires_confirmation(&self, _args: &Value) -> bool {
        self.require_confirmation
    }

    /// Returns the custom hint if one was set. Otherwise returns a generic prompt that
    /// names the tool.
    fn confirmation_hint(&self, _args: &Value) -> String {
        match &self.confirmation_hint {
            Some(hint) => hint.clone(),
            None => format!("Approve execution of tool `{}`?", self.name),
        }
    }

    /// Builds the declaration. `Schema::object()` is used when no schema was provided.
    fn declaration(&self) -> Option<FunctionDeclaration> {
        let schema = match &self.parameters {
            Some(params) => params.clone(),
            None => Schema::object(),
        };
        let decl = FunctionDeclaration::new(&self.name, &self.description).with_parameters(schema);
        Some(decl)
    }

    /// Runs the wrapped callback with the given arguments and context.
    async fn run(&self, args: Value, ctx: &mut ToolContext) -> Result<Value> {
        let callback = &self.f;
        callback(args, ctx).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Arc;

    /// Builds a `ToolContext` around the crate's shared test invocation context.
    fn test_ctx() -> ToolContext {
        let invocation = crate::core::testing::test_invocation_context();
        ToolContext::new(Arc::new(invocation))
    }

    #[tokio::test]
    async fn echo_tool_runs() {
        let schema = Schema::object()
            .property("msg", Schema::string())
            .require("msg");
        let tool = FunctionTool::from_async(
            "echo",
            "echo the args",
            Some(schema),
            |args: Value, _ctx: &mut ToolContext| async move { Ok(args) },
        );

        let mut ctx = test_ctx();
        let echoed = tool.run(json!({"msg": "hi"}), &mut ctx).await.unwrap();
        assert_eq!(echoed["msg"], "hi");

        let decl = tool.declaration().unwrap();
        assert!(decl.parameters.is_some());
    }
}