1use serde::{de::DeserializeOwned, Deserialize, Serialize};
2use std::{future::Future, pin::Pin, sync::Arc};
3
4use crate::AgentState;
5
6#[derive(Clone, Debug, Serialize, Deserialize)]
7pub struct EmptyParams {}
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct ActionParameter {
11 pub name: String,
12 pub description: String,
13 #[serde(rename = "type")]
14 pub param_type: String,
15 pub required: bool,
16}
17
18#[derive(Clone, Debug, Serialize, Deserialize)]
19pub struct ActionDefinition {
20 pub name: String,
21 pub description: String,
22 pub parameters: Vec<ActionParameter>,
23}
24
25pub type Handler<S> = Box<
26 dyn Fn(
27 serde_json::Value,
28 serde_json::Value,
29 AgentState<S>,
30 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>>
31 + Send
32 + Sync,
33>;
34pub type ConfirmHandler<S> = Arc<
35 Box<
36 dyn Fn(
37 serde_json::Value,
38 serde_json::Value,
39 AgentState<S>,
40 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>>
41 + Send
42 + Sync,
43 >,
44>;
45pub struct FunctionAction<S: Send + Sync + Clone + 'static> {
46 definition: ActionDefinition,
47 handler: Handler<S>,
48 pub confirm_handler: Option<ConfirmHandler<S>>,
49}
50
51impl<S: Send + Sync + Clone + 'static> FunctionAction<S> {
52 pub fn definition(&self) -> ActionDefinition {
53 self.definition.clone()
54 }
55
56 pub fn execute(
57 &self,
58 params: serde_json::Value,
59 send_state: serde_json::Value,
60 state: AgentState<S>,
61 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>> {
62 (self.handler)(params, send_state, state)
63 }
64
65 pub fn confirm(
66 &self,
67 params: serde_json::Value,
68 send_state: serde_json::Value,
69 state: AgentState<S>,
70 ) -> Option<Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>>> {
71 self.confirm_handler
72 .as_ref()
73 .map(|handler| handler(params, send_state, state))
74 }
75}
76
77pub type EmptyConfirmHandler<T, V, S> =
78 fn(T, V, AgentState<S>) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>>;
79pub struct ActionBuilder<F, P, V, S, Q = String, CF = EmptyConfirmHandler<Q, V, S>> {
86 name: String,
87 description: String,
88 parameters: Vec<ActionParameter>,
89 handler: F,
90 confirm_handler: Option<CF>,
91 _phantom_handler_input: std::marker::PhantomData<P>,
92 _phantom_confirm_handler_input: std::marker::PhantomData<Q>,
93 _phantom_state: std::marker::PhantomData<S>,
94 _phantom_send_state: std::marker::PhantomData<V>,
95}
96
97impl<F, CF, P, Q, S, V, Fut, CFut> ActionBuilder<F, P, V, S, Q, CF>
98where
99 F: Fn(P, V, AgentState<S>) -> Fut + Send + Sync + Clone + 'static,
101 Fut: Future<Output = Result<Q, String>> + Send + Sync + 'static,
102 P: DeserializeOwned + Send + 'static,
103 CF: Fn(Q, V, AgentState<S>) -> CFut + Send + Sync + Clone + 'static,
105 CFut: Future<Output = Result<String, String>> + Send + Sync + 'static,
106 Q: Serialize + DeserializeOwned + Send + 'static,
107 S: Send + Sync + Clone + 'static,
108 V: Serialize + DeserializeOwned + Send + 'static,
109{
110 pub fn new(name: impl Into<String>, handler: F, confirm_handler: Option<CF>) -> Self {
111 Self {
112 name: name.into(),
113 description: String::new(),
114 parameters: Vec::new(),
115 handler,
116 confirm_handler,
117 _phantom_handler_input: std::marker::PhantomData,
118 _phantom_confirm_handler_input: std::marker::PhantomData,
119 _phantom_state: std::marker::PhantomData,
120 _phantom_send_state: std::marker::PhantomData,
121 }
122 }
123
124 pub fn description(mut self, description: impl Into<String>) -> Self {
125 self.description = description.into();
126 self
127 }
128
129 pub fn parameter(
130 mut self,
131 name: impl Into<String>,
132 description: impl Into<String>,
133 param_type: impl Into<String>,
134 required: bool,
135 ) -> Self {
136 self.parameters.push(ActionParameter {
137 name: name.into(),
138 description: description.into(),
139 param_type: param_type.into(),
140 required,
141 });
142 self
143 }
144
145 pub fn build(self) -> FunctionAction<S> {
146 let handler = self.handler;
147 FunctionAction {
148 definition: ActionDefinition {
149 name: self.name,
150 description: self.description,
151 parameters: self.parameters,
152 },
153 handler: Box::new(
154 move |params: serde_json::Value,
155 send_state: serde_json::Value,
156 state: AgentState<S>| {
157 let handler = handler.clone();
158 Box::pin(async move {
159 let params = serde_json::from_value(params)
160 .map_err(|e| format!("Invalid parameters: {}", e))?;
161 let send_state = serde_json::from_value(send_state)
162 .map_err(|e| format!("Invalid send_state: {}", e))?;
163 let result = handler(params, send_state, state).await?;
164 serde_json::to_string(&result)
166 .map_err(|e| format!("Failed to serialize result: {}", e))
167 })
168 },
169 ),
170 confirm_handler: self.confirm_handler.map(|handler| {
171 Arc::new(Box::new(
172 move |params: serde_json::Value,
173 send_state: serde_json::Value,
174 state: AgentState<S>| {
175 let handler = handler.clone();
176 let fut: Pin<
177 Box<dyn Future<Output = Result<String, String>> + Send + Sync>,
178 > = Box::pin(async move {
179 let params = serde_json::from_value::<Q>(params)
180 .map_err(|e| format!("Invalid parameters: {}", e))?;
181 let send_state = serde_json::from_value(send_state)
182 .map_err(|e| format!("Invalid send_state: {}", e))?;
183 handler(params, send_state, state).await
184 });
185 fut
186 },
187 ) as Box<dyn Fn(_, _, _) -> _ + Send + Sync>)
188 }),
189 }
190 }
191}
192
193pub trait ActionGroup<S: Send + Sync + Clone + 'static> {
195 fn actions(&self) -> &[Arc<FunctionAction<S>>];
196}
197
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::Mutex;

    use super::*;

    /// Typed input of the `get_weather` test action; `units` may be omitted.
    #[derive(Debug, Deserialize)]
    struct WeatherParams {
        location: String,
        #[serde(default)]
        units: Option<String>,
    }

    /// Builds an action from a typed handler and checks its definition, two
    /// successful executions and one execution with invalid parameters.
    #[tokio::test]
    async fn test_typed_function_action() {
        async fn weather(
            params: WeatherParams,
            _send_state: serde_json::Value,
            _state: AgentState<()>,
        ) -> Result<String, String> {
            println!("Executing action in function: {:?}", "get_weather");
            let WeatherParams { location, units } = params;
            let units = units.unwrap_or_else(|| "celsius".to_string());
            Ok(format!("Weather in {} ({}): Sunny", location, units))
        }

        let action = ActionBuilder::<_, _, _, _>::new("get_weather", weather, None)
            .description("Get the weather for a location")
            .parameter("location", "The city to get weather for", "string", true)
            .parameter(
                "units",
                "Temperature units (celsius/fahrenheit)",
                "string",
                false,
            )
            .build();

        // The definition reflects what was registered on the builder.
        let definition = action.definition();
        assert_eq!(definition.name, "get_weather");
        assert_eq!(definition.parameters.len(), 2);

        let state = Arc::new(Mutex::new(()));
        let send_state = serde_json::json!({});

        // Valid inputs: explicit units first, then the default units.
        let success_cases = vec![
            (
                serde_json::json!({
                    "location": "London",
                    "units": "fahrenheit"
                }),
                "Weather in London (fahrenheit): Sunny",
            ),
            (
                serde_json::json!({
                    "location": "Paris"
                }),
                "Weather in Paris (celsius): Sunny",
            ),
        ];
        for (params, plain_text) in success_cases {
            let result = action
                .execute(params, send_state.clone(), state.clone())
                .await
                .unwrap();
            // The handler result is returned JSON-encoded.
            let expected = serde_json::to_string(&plain_text).unwrap();
            assert_eq!(result, expected);
        }

        // The mandatory `location` field is missing, so execution must fail.
        let bad_params = serde_json::json!({
            "wrong_field": "London"
        });
        let outcome = action.execute(bad_params, send_state, state.clone()).await;
        assert!(outcome.is_err());
    }
}