//! plainllm 1.2.0
//!
//! A plain & simple LLM client.
//! Documentation
use super::chat_completion::{FunctionDefinition, Parameters, Tool};
use futures::{FutureExt, future::BoxFuture};
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

//////////////////////////////
// 2. ToolRegistry
//////////////////////////////

/// A function-calling tool definition.
///
/// Serializes as `{"type": <tool_type>, "function": <FunctionDef>}`. Both constructors in
/// this file (`FunctionTool::from_type` and `FunctionToolBuilder`) set `tool_type` to
/// `"function"`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionTool {
    /// Tool kind; emitted under the JSON key `type`.
    #[serde(rename = "type")]
    tool_type: String,
    /// Name, description and parameter schema of the function.
    function: FunctionDef,
}

/// The `function` part of a [`FunctionTool`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDef {
    /// Function name; `ToolRegistry` uses it as the lookup key for handlers.
    name: String,
    /// Human-readable description of what the function does.
    description: String,
    /// Schema of the arguments object the function accepts.
    parameters: ParametersDef,
}

/// JSON-Schema-style description of a tool's arguments object.
///
/// Field names are mapped to the JSON Schema keywords they represent (`type`,
/// `properties`, `required`, `additionalProperties`), so a serialized value can be sent
/// as a function-calling parameter schema.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ParametersDef {
    /// JSON Schema `type` of the arguments container (the builders here use `"object"`).
    #[serde(rename = "type")]
    param_type: String,
    /// Parameter name -> its type and description.
    properties: HashMap<String, PropertyDef>,
    /// Names of parameters the caller must supply.
    required: Vec<String>,
    /// Whether keys not listed in `properties` are allowed.
    ///
    /// JSON Schema spells this keyword `additionalProperties`. The snake_case spelling
    /// used by earlier versions is still accepted when deserializing.
    #[serde(
        default,
        rename = "additionalProperties",
        alias = "additional_properties"
    )]
    additional_properties: bool,
}

/// Type and description of a single tool parameter.
///
/// Only these two keywords are kept; nested schema details (for example array `items`)
/// have no field here.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PropertyDef {
    /// JSON Schema type name of the parameter; emitted under the JSON key `type`.
    #[serde(rename = "type")]
    prop_type: String,
    /// Parameter description (`FunctionTool::from_type` stores `""` when none is available).
    description: String,
}

impl FunctionTool {
    pub fn from_type<T: JsonSchema>(name: &str, description: &str) -> Self {
        let schema = schema_for!(T);
        let val = serde_json::to_value(schema).unwrap_or_default();
        let param_type = val
            .get("type")
            .and_then(|v| v.as_str())
            .unwrap_or("object")
            .to_string();
        let mut properties = HashMap::new();
        if let Some(Value::Object(map)) = val.get("properties") {
            for (k, v) in map {
                let t = v
                    .get("type")
                    .and_then(|v| v.as_str())
                    .unwrap_or("object")
                    .to_string();
                let desc = v
                    .get("description")
                    .and_then(|v| v.as_str())
                    .unwrap_or("")
                    .to_string();
                properties.insert(
                    k.clone(),
                    PropertyDef {
                        prop_type: t,
                        description: desc,
                    },
                );
            }
        }
        let mut required = Vec::new();
        if let Some(Value::Array(req)) = val.get("required") {
            for r in req {
                if let Some(s) = r.as_str() {
                    required.push(s.to_string());
                }
            }
        }
        FunctionTool {
            tool_type: "function".to_string(),
            function: FunctionDef {
                name: name.to_string(),
                description: description.to_string(),
                parameters: ParametersDef {
                    param_type,
                    properties,
                    required,
                    additional_properties: false,
                },
            },
        }
    }
}

/// Parameter-adding stage of a [`FunctionToolBuilder`], created by
/// `FunctionToolBuilder::params`.
///
/// `F` is a phantom marker picked by the caller of `params`; no method here reads or
/// constrains it.
pub struct ParameterBuilder<F>(FunctionToolBuilder, std::marker::PhantomData<F>);

impl<F> ParameterBuilder<F> {
    /// Declares the parameter `name` with JSON Schema type `type_name` and `description`.
    ///
    /// Declaring a name again replaces its type and description. The name's membership
    /// in the required list follows the latest `required` flag, so the list never holds
    /// a name twice. JSON Schema requires `required` entries to be unique.
    pub fn param<T: Into<String>>(
        mut self,
        name: T,
        type_name: T,
        description: T,
        required: bool,
    ) -> Self {
        let name_str = name.into();
        if required {
            // Keep the original position when the name is already required.
            if !self.0.required.contains(&name_str) {
                self.0.required.push(name_str.clone());
            }
        } else {
            // Drop a stale entry left by an earlier `required = true` declaration.
            self.0.required.retain(|existing| existing != &name_str);
        }
        self.0.properties.insert(
            name_str,
            PropertyDef {
                prop_type: type_name.into(),
                description: description.into(),
            },
        );
        self
    }

    /// Finishes the builder and returns the assembled [`FunctionTool`].
    pub fn build(self) -> FunctionTool {
        self.0.build_internal()
    }
}

/// Entry point for assembling a [`FunctionTool`] by hand.
///
/// Typical flow: `FunctionToolBuilder::new(..)`, then `.params::<F>()`, then `.param(..)`
/// calls, then `.build()`. The resulting parameter schema always has type `"object"` and
/// `additional_properties == false`.
pub struct FunctionToolBuilder {
    name: String,
    description: String,
    properties: HashMap<String, PropertyDef>,
    required: Vec<String>,
}

impl FunctionToolBuilder {
    /// Starts a builder for a tool called `name` with the given `description` and no
    /// parameters yet.
    pub fn new<T: Into<String>>(name: T, description: T) -> Self {
        let name: String = name.into();
        let description: String = description.into();
        Self {
            name,
            description,
            properties: HashMap::default(),
            required: Vec::default(),
        }
    }

    /// Moves into the parameter-declaration stage.
    pub fn params<F>(self) -> ParameterBuilder<F> {
        ParameterBuilder::<F>(self, std::marker::PhantomData)
    }

    /// Consumes the builder and wraps its collected parts into a `"function"` tool.
    fn build_internal(self) -> FunctionTool {
        let Self {
            name,
            description,
            properties,
            required,
        } = self;
        let parameters = ParametersDef {
            param_type: String::from("object"),
            properties,
            required,
            additional_properties: false,
        };
        FunctionTool {
            tool_type: String::from("function"),
            function: FunctionDef {
                name,
                description,
                parameters,
            },
        }
    }
}

/// Registry of callable tools and their definitions.
///
/// Both maps are keyed by the tool's `function.name` (see `register`).
pub struct ToolRegistry {
    /// Tool name -> handler. A handler takes the JSON arguments and returns a boxed
    /// `'static` future resolving to the JSON result or an error message.
    tools: HashMap<
        String,
        Box<dyn Fn(Value) -> BoxFuture<'static, Result<Value, String>> + Send + Sync>,
    >,
    /// Tool name -> definition advertised to the model.
    definitions: HashMap<String, FunctionTool>,
}

impl Default for ToolRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl ToolRegistry {
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
            definitions: HashMap::new(),
        }
    }

    pub fn register<F, Fut, Args, Ret>(&mut self, tool: FunctionTool, f: F) -> &mut Self
    where
        F: Fn(Args) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Ret> + Send + 'static,
        Args: for<'de> Deserialize<'de> + Send + 'static,
        Ret: Serialize + 'static,
    {
        let name = tool.function.name.clone();
        self.tools.insert(
            name.clone(),
            Box::new(move |args| {
                let f = f.clone();
                let args: Result<Args, _> = serde_json::from_value(args);
                async move {
                    let args =
                        args.map_err(|e| format!("Failed to deserialize arguments: {}", e))?;
                    let result = f(args).await;
                    serde_json::to_value(result)
                        .map_err(|e| format!("Failed to serialize result: {}", e))
                }
                .boxed()
            }),
        );
        self.definitions.insert(name, tool);
        self
    }

    pub fn register_with_callbacks<F, Fut, Args, Ret, C1, C2>(
        &mut self,
        tool: FunctionTool,
        f: F,
        on_call: Option<C1>,
        on_result: Option<C2>,
    ) -> &mut Self
    where
        F: Fn(Args) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Ret> + Send + 'static,
        Args: for<'de> Deserialize<'de> + Clone + Send + 'static,
        Ret: Serialize + Clone + 'static,
        C1: Fn(&Args) + Send + Sync + Clone + 'static,
        C2: Fn(&Args, &Ret) + Send + Sync + Clone + 'static,
    {
        let name = tool.function.name.clone();
        self.tools.insert(
            name.clone(),
            Box::new(move |args| {
                let f = f.clone();
                let on_call = on_call.clone();
                let on_result = on_result.clone();
                let args: Result<Args, _> = serde_json::from_value(args);
                async move {
                    let args =
                        args.map_err(|e| format!("Failed to deserialize arguments: {}", e))?;
                    if let Some(cb) = on_call.as_ref() {
                        cb(&args);
                    }
                    let result = f(args.clone()).await;
                    if let Some(cb) = on_result.as_ref() {
                        cb(&args, &result);
                    }
                    serde_json::to_value(result)
                        .map_err(|e| format!("Failed to serialize result: {}", e))
                }
                .boxed()
            }),
        );
        self.definitions.insert(name, tool);
        self
    }

    pub fn get_definitions(&self) -> Vec<&FunctionTool> {
        self.definitions.values().collect()
    }

    pub async fn call(&self, name: &str, args: Value) -> Result<Value, String> {
        tracing::info!("tool {}", name);
        tracing::debug!("args: {}", args);
        match self.tools.get(name) {
            Some(f) => f(args).await,
            None => Err(format!("Tool '{}' not found", name)),
        }
    }

    pub fn to_api_tools(&self) -> Vec<Tool> {
        self.get_definitions()
            .into_iter()
            .map(|tool| Tool {
                tool_type: tool.tool_type.clone(),
                function: FunctionDefinition {
                    name: tool.function.name.clone(),
                    description: tool.function.description.clone(),
                    parameters: Parameters {
                        param_type: tool.function.parameters.param_type.clone(),
                        properties: serde_json::to_value(&tool.function.parameters.properties)
                            .unwrap_or_else(|_| serde_json::json!({})),
                        required: tool.function.parameters.required.clone(),
                        additional_properties: tool.function.parameters.additional_properties,
                    },
                },
            })
            .collect()
    }
}