//! lutum-protocol 0.1.0
//!
//! Core traits and request/response types for lutum.
use std::fmt;

use schemars::{JsonSchema, Schema, schema_for};
use serde::{Serialize, de::DeserializeOwned};
use thiserror::Error;

use crate::conversation::{ToolMetadata, ToolUse};

/// Static description of a single tool: its name, a human-readable
/// description, and a function that produces its input schema.
///
/// `Debug` is implemented by hand elsewhere in this file rather than derived.
#[derive(Clone, Copy)]
pub struct ToolDef {
    /// Name of the tool.
    pub name: &'static str,
    /// Description of the tool.
    pub description: &'static str,
    // Function pointer that builds the schema on demand; private, and invoked
    // by `ToolDef::input_schema`.
    schema: fn() -> Schema,
}

/// Manual `Debug` that prints `name` and `description` and marks the struct as
/// non-exhaustive (`..`) instead of printing the `schema` function pointer.
impl fmt::Debug for ToolDef {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure so the omitted field is spelled out explicitly.
        let Self {
            name,
            description,
            schema: _,
        } = self;

        let mut out = f.debug_struct("ToolDef");
        out.field("name", name);
        out.field("description", description);
        out.finish_non_exhaustive()
    }
}

impl ToolDef {
    pub const fn new(
        name: &'static str,
        description: &'static str,
        schema: fn() -> Schema,
    ) -> Self {
        Self {
            name,
            description,
            schema,
        }
    }

    pub fn input_schema(&self) -> Schema {
        (self.schema)()
    }

    pub fn for_input<Input>() -> Self
    where
        Input: ToolInput,
    {
        Self::new(Input::NAME, Input::DESCRIPTION, || schema_for!(Input))
    }
}

/// Errors returned when a tool call cannot be parsed into a typed call
/// (the error type of `Toolset::parse_tool_call`).
#[derive(Debug, Error)]
pub enum ToolCallError {
    /// The call names a tool that is not known.
    #[error("unknown tool `{name}`")]
    UnknownTool { name: String },
    /// The call for tool `name` could not be deserialized.
    #[error("failed to deserialize tool call for `{name}`: {source}")]
    Deserialize {
        name: String,
        /// Underlying `serde_json` error; also exposed as the error source.
        #[source]
        source: serde_json::Error,
    },
}

/// Errors produced while building a `ToolUse` from a tool's output
/// (the error type of `ToolInput::tool_use`).
#[derive(Debug, Error)]
pub enum ToolUseError {
    /// The metadata's tool name (`actual`) differs from the tool the output
    /// belongs to (`expected`).
    #[error("tool metadata for `{actual}` does not match expected tool `{expected}`")]
    MismatchedToolName {
        expected: &'static str,
        actual: String,
    },
    /// The tool output could not be serialized; converts from
    /// `serde_json::Error` via `From`.
    #[error("failed to serialize tool output: {0}")]
    Serialize(#[from] serde_json::Error),
}

/// Error for running a tool and packaging its result, generic over the
/// tool's own error type `E`.
#[derive(Debug, Error)]
pub enum ToolExecutionError<E> {
    /// The tool itself failed with `E`.
    #[error("tool execution failed: {0}")]
    Execute(E),
    /// Building the tool use failed; converts from [`ToolUseError`] via `From`.
    #[error("failed to build tool use: {0}")]
    ToolUse(#[from] ToolUseError),
}

/// Input type of a tool, carrying the tool's static name and description and
/// naming the type of its output.
pub trait ToolInput:
    Serialize + DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static
{
    /// Output type of the tool.
    type Output: Serialize + DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static;

    /// Name of the tool.
    const NAME: &'static str;
    /// Description of the tool.
    const DESCRIPTION: &'static str;

    /// Pairs `metadata` with the serialized `output` to form a [`ToolUse`].
    ///
    /// # Errors
    ///
    /// Returns [`ToolUseError::MismatchedToolName`] when the metadata's name is
    /// not [`Self::NAME`], and [`ToolUseError::Serialize`] when the output
    /// cannot be serialized. The name is checked before serialization.
    fn tool_use(metadata: ToolMetadata, output: Self::Output) -> Result<ToolUse, ToolUseError> {
        let name_matches = metadata.name.as_str() == Self::NAME;
        if name_matches {
            let payload = crate::conversation::RawJson::from_serializable(&output)?;
            return Ok(metadata.into_tool_use(payload));
        }

        Err(ToolUseError::MismatchedToolName {
            expected: Self::NAME,
            actual: metadata.name.as_str().to_string(),
        })
    }
}

/// Implemented by tool-call types so their [`ToolMetadata`] can be borrowed.
pub trait ToolCallWrapper {
    /// Returns the metadata of this tool call.
    fn metadata(&self) -> &ToolMetadata;
}

/// `Infallible` has no values, so this impl exists only to satisfy the trait
/// bound; `metadata` can never actually be called.
impl ToolCallWrapper for std::convert::Infallible {
    fn metadata(&self) -> &ToolMetadata {
        // Copy the uninhabited value out and match on it with no arms.
        let never: std::convert::Infallible = *self;
        match never {}
    }
}

/// Copyable, hashable, serializable identifier for a tool, parameterized by
/// the toolset type `T` it belongs to (see `Toolset::Selector`).
pub trait ToolSelector<T: ?Sized>:
    Copy
    + Clone
    + fmt::Debug
    + Eq
    + PartialEq
    + std::hash::Hash
    + Serialize
    + DeserializeOwned
    + JsonSchema
    + Send
    + Sync
    + 'static
{
    /// Returns the static name associated with this selector.
    fn name(self) -> &'static str;

    /// Returns the static [`ToolDef`] associated with this selector.
    fn definition(self) -> &'static ToolDef;

    /// Returns a static slice of selectors — presumably every selector of the
    /// toolset; confirm against implementors.
    fn all() -> &'static [Self];

    /// Looks up a selector by name; presumably `None` means no selector has
    /// that name — confirm against implementors.
    fn try_from_name(name: &str) -> Option<Self>;
}

/// A set of tools: ties together the typed tool-call representation, the
/// selector type used to pick individual tools, and the tool definitions.
pub trait Toolset: Send + Sync + 'static {
    /// Typed representation of a parsed tool call.
    type ToolCall: ToolCallWrapper + Clone + fmt::Debug + Eq + PartialEq + Send + Sync + 'static;
    /// Selector type identifying individual tools of this set.
    type Selector: ToolSelector<Self>;

    /// Returns the static tool definitions of this set.
    fn definitions() -> &'static [ToolDef];

    /// Returns the definition of each given selector, in iteration order.
    /// Duplicated selectors yield duplicated definitions.
    fn definitions_for<I>(selectors: I) -> Vec<&'static ToolDef>
    where
        I: IntoIterator<Item = Self::Selector>,
    {
        let mut defs = Vec::new();
        for selector in selectors {
            defs.push(selector.definition());
        }
        defs
    }

    /// Parses a call described by `metadata` into [`Self::ToolCall`].
    fn parse_tool_call(metadata: ToolMetadata) -> Result<Self::ToolCall, ToolCallError>;
}

/// Policy describing whether and which tools of toolset `T` are in play.
///
/// `uses_tools` is `false` only for `Disabled`; `requires_tools` is `true`
/// only for the two `Require*` variants. The exact meaning of "allow" versus
/// "require" is defined by the consumer of this type — confirm there.
///
/// NOTE(review): the derived `Clone`/`Debug`/`Eq`/`PartialEq` impls place a
/// bound on `T` itself, not only on `T::Selector`.
#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub enum ToolPolicy<T: Toolset> {
    /// No tools are used; this is the default.
    #[default]
    Disabled,
    /// Tools are used without restricting to a selector list.
    AllowAll,
    /// Tools are used, restricted to the listed selectors.
    AllowOnly(Vec<T::Selector>),
    /// Tools are required, without restricting to a selector list.
    RequireAll,
    /// Tools are required, restricted to the listed selectors.
    RequireOnly(Vec<T::Selector>),
}

impl<T> ToolPolicy<T>
where
    T: Toolset,
{
    /// Builds an `AllowOnly` policy from `selectors`, or `Disabled` when the
    /// iterator yields nothing.
    pub fn allow_only(selectors: impl IntoIterator<Item = T::Selector>) -> Self {
        let picked: Vec<T::Selector> = selectors.into_iter().collect();
        if picked.is_empty() {
            return Self::Disabled;
        }
        Self::AllowOnly(picked)
    }

    /// Builds a `RequireOnly` policy from `selectors`, or `Disabled` when the
    /// iterator yields nothing.
    pub fn require_only(selectors: impl IntoIterator<Item = T::Selector>) -> Self {
        let picked: Vec<T::Selector> = selectors.into_iter().collect();
        if picked.is_empty() {
            return Self::Disabled;
        }
        Self::RequireOnly(picked)
    }

    /// Returns `true` for every variant except `Disabled`.
    pub fn uses_tools(&self) -> bool {
        match self {
            Self::Disabled => false,
            Self::AllowAll | Self::AllowOnly(_) | Self::RequireAll | Self::RequireOnly(_) => true,
        }
    }

    /// Returns `true` for `RequireAll` and `RequireOnly`.
    pub fn requires_tools(&self) -> bool {
        match self {
            Self::RequireAll | Self::RequireOnly(_) => true,
            Self::Disabled | Self::AllowAll | Self::AllowOnly(_) => false,
        }
    }

    /// Returns the explicit selector list of `AllowOnly` / `RequireOnly`, or
    /// `None` for the variants that carry no list.
    pub fn selected(&self) -> Option<&[T::Selector]> {
        match self {
            Self::AllowOnly(picked) | Self::RequireOnly(picked) => Some(picked.as_slice()),
            Self::Disabled | Self::AllowAll | Self::RequireAll => None,
        }
    }
}

/// Selector type with no variants: no value of it can exist, so it can stand
/// in as the selector of a toolset that has no tools.
#[derive(
    Clone,
    Copy,
    Debug,
    Eq,
    PartialEq,
    Hash,
    serde::Serialize,
    serde::Deserialize,
    schemars::JsonSchema,
)]
pub enum NoToolSelector {}

/// Selector impl for the empty toolset. The by-value methods can never run
/// because `NoToolSelector` has no values; the associated functions report
/// that there are no selectors.
impl ToolSelector<NoTools> for NoToolSelector {
    fn name(self) -> &'static str {
        // Uninhabited: there is no variant to return a name for.
        match self {}
    }

    fn definition(self) -> &'static ToolDef {
        // Uninhabited: there is no variant to return a definition for.
        match self {}
    }

    fn all() -> &'static [Self] {
        const NO_SELECTORS: &[NoToolSelector] = &[];
        NO_SELECTORS
    }

    fn try_from_name(_name: &str) -> Option<Self> {
        // No name can map to a selector that has no variants.
        Option::None
    }
}

/// Marker toolset that contains no tools.
///
/// Derives `Eq`, `PartialEq` and `Hash` in addition to the basic traits: the
/// derived comparison impls on `ToolPolicy<T>` put a bound on `T` itself, so
/// without them `ToolPolicy<NoTools>` could not be compared.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Hash)]
pub struct NoTools;

/// The empty toolset: it has no definitions, and every tool call is rejected
/// as unknown.
impl Toolset for NoTools {
    type ToolCall = std::convert::Infallible;
    type Selector = NoToolSelector;

    fn definitions() -> &'static [ToolDef] {
        const NO_DEFINITIONS: &[ToolDef] = &[];
        NO_DEFINITIONS
    }

    fn parse_tool_call(metadata: ToolMetadata) -> Result<Self::ToolCall, ToolCallError> {
        // `ToolCall` is uninhabited, so the only possible outcome is an error.
        let name = metadata.name.as_str().to_string();
        Err(ToolCallError::UnknownTool { name })
    }
}