// chatgpt/functions/types.rs
use std::marker::PhantomData;

use async_trait::async_trait;
use schemars::schema_for;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize, Serializer};
use serde_json::Value;

use crate::functions::{CallableAsyncFunction, FunctionArgument};

9#[derive(Debug, Clone)]
11pub struct FunctionDescriptor<A: FunctionArgument> {
12 pub name: &'static str,
14 pub description: &'static str,
16 pub parameters: PhantomData<A>,
18}
19
20impl<A: FunctionArgument> Serialize for FunctionDescriptor<A> {
21 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
22 where
23 S: Serializer,
24 {
25 let mut s = serializer.serialize_struct("FunctionDescriptor", 3)?;
26 s.serialize_field("name", self.name)?;
27 s.serialize_field("description", self.description)?;
28 let mut schema = schema_for!(A);
29 schema.meta_schema = None; s.serialize_field("parameters", &schema)?;
31
32 s.end()
33 }
34}
35
36#[async_trait]
38pub trait GptFunctionHolder: Send + Sync {
39 async fn try_invoke(&self, args: &str) -> crate::Result<serde_json::Value>;
41}
42
43#[derive(Debug, Clone)]
45pub struct GptFunction<A: FunctionArgument, C: CallableAsyncFunction<A>>
46where
47 A: Send + Sync,
48 C: Send + Sync,
49{
50 pub descriptor: FunctionDescriptor<A>,
52 pub callable: PhantomData<C>,
54}
55
56#[async_trait]
57impl<A: FunctionArgument + Send + Sync, C: CallableAsyncFunction<A> + Send + Sync> GptFunctionHolder
58 for GptFunction<A, C>
59{
60 async fn try_invoke(&self, args: &str) -> crate::Result<Value> {
61 let args_value: A = serde_json::from_str(args).map_err(crate::err::Error::from)?;
62 C::invoke(args_value).await
63 }
64}
65
66#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Serialize)]
68pub enum FunctionCallingMode {
69 Auto,
71 None,
73}
74
75#[derive(Serialize, Debug, Copy, Clone, Default, PartialOrd, PartialEq)]
77pub enum FunctionValidationStrategy {
78 Strict,
80 #[default]
82 Loose,
83}
84
85#[derive(Debug, Clone, PartialOrd, PartialEq, Serialize, Deserialize)]
87pub struct FunctionCall {
88 pub name: String,
90 pub arguments: String,
92}