ferrox_actions/
action.rs

1use serde::{de::DeserializeOwned, Deserialize, Serialize};
2use std::{future::Future, pin::Pin, sync::Arc};
3
4use crate::AgentState;
5
6#[derive(Clone, Debug, Serialize, Deserialize)]
7pub struct EmptyParams {}
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct ActionParameter {
11    pub name: String,
12    pub description: String,
13    #[serde(rename = "type")]
14    pub param_type: String,
15    pub required: bool,
16}
17
18#[derive(Clone, Debug, Serialize, Deserialize)]
19pub struct ActionDefinition {
20    pub name: String,
21    pub description: String,
22    pub parameters: Vec<ActionParameter>,
23}
24
25pub type Handler<S> = Box<
26    dyn Fn(
27            serde_json::Value,
28            serde_json::Value,
29            AgentState<S>,
30        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>>
31        + Send
32        + Sync,
33>;
34pub type ConfirmHandler<S> = Arc<
35    Box<
36        dyn Fn(
37                serde_json::Value,
38                serde_json::Value,
39                AgentState<S>,
40            ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>>
41            + Send
42            + Sync,
43    >,
44>;
45pub struct FunctionAction<S: Send + Sync + Clone + 'static> {
46    definition: ActionDefinition,
47    handler: Handler<S>,
48    pub confirm_handler: Option<ConfirmHandler<S>>,
49}
50
51impl<S: Send + Sync + Clone + 'static> FunctionAction<S> {
52    pub fn definition(&self) -> ActionDefinition {
53        self.definition.clone()
54    }
55
56    pub fn execute(
57        &self,
58        params: serde_json::Value,
59        send_state: serde_json::Value,
60        state: AgentState<S>,
61    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>> {
62        (self.handler)(params, send_state, state)
63    }
64
65    pub fn confirm(
66        &self,
67        params: serde_json::Value,
68        send_state: serde_json::Value,
69        state: AgentState<S>,
70    ) -> Option<Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>>> {
71        self.confirm_handler
72            .as_ref()
73            .map(|handler| handler(params, send_state, state))
74    }
75}
76
77pub type EmptyConfirmHandler<T, V, S> =
78    fn(T, V, AgentState<S>) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + Sync>>;
79/// A builder to create actions from async functions with typed parameters
80/// F: The handler function type
81/// P: Input parameters for the handler
82/// S: State type
83/// Q: Output type from handler and input type for confirm handler
84/// CF: Confirm handler function type (defaults to EmptyConfirmHandler)
85pub struct ActionBuilder<F, P, V, S, Q = String, CF = EmptyConfirmHandler<Q, V, S>> {
86    name: String,
87    description: String,
88    parameters: Vec<ActionParameter>,
89    handler: F,
90    confirm_handler: Option<CF>,
91    _phantom_handler_input: std::marker::PhantomData<P>,
92    _phantom_confirm_handler_input: std::marker::PhantomData<Q>,
93    _phantom_state: std::marker::PhantomData<S>,
94    _phantom_send_state: std::marker::PhantomData<V>,
95}
96
97impl<F, CF, P, Q, S, V, Fut, CFut> ActionBuilder<F, P, V, S, Q, CF>
98where
99    // Handler F takes P and returns Q
100    F: Fn(P, V, AgentState<S>) -> Fut + Send + Sync + Clone + 'static,
101    Fut: Future<Output = Result<Q, String>> + Send + Sync + 'static,
102    P: DeserializeOwned + Send + 'static,
103    // Confirm handler CF takes Q and returns String
104    CF: Fn(Q, V, AgentState<S>) -> CFut + Send + Sync + Clone + 'static,
105    CFut: Future<Output = Result<String, String>> + Send + Sync + 'static,
106    Q: Serialize + DeserializeOwned + Send + 'static,
107    S: Send + Sync + Clone + 'static,
108    V: Serialize + DeserializeOwned + Send + 'static,
109{
110    pub fn new(name: impl Into<String>, handler: F, confirm_handler: Option<CF>) -> Self {
111        Self {
112            name: name.into(),
113            description: String::new(),
114            parameters: Vec::new(),
115            handler,
116            confirm_handler,
117            _phantom_handler_input: std::marker::PhantomData,
118            _phantom_confirm_handler_input: std::marker::PhantomData,
119            _phantom_state: std::marker::PhantomData,
120            _phantom_send_state: std::marker::PhantomData,
121        }
122    }
123
124    pub fn description(mut self, description: impl Into<String>) -> Self {
125        self.description = description.into();
126        self
127    }
128
129    pub fn parameter(
130        mut self,
131        name: impl Into<String>,
132        description: impl Into<String>,
133        param_type: impl Into<String>,
134        required: bool,
135    ) -> Self {
136        self.parameters.push(ActionParameter {
137            name: name.into(),
138            description: description.into(),
139            param_type: param_type.into(),
140            required,
141        });
142        self
143    }
144
145    pub fn build(self) -> FunctionAction<S> {
146        let handler = self.handler;
147        FunctionAction {
148            definition: ActionDefinition {
149                name: self.name,
150                description: self.description,
151                parameters: self.parameters,
152            },
153            handler: Box::new(
154                move |params: serde_json::Value,
155                      send_state: serde_json::Value,
156                      state: AgentState<S>| {
157                    let handler = handler.clone();
158                    Box::pin(async move {
159                        let params = serde_json::from_value(params)
160                            .map_err(|e| format!("Invalid parameters: {}", e))?;
161                        let send_state = serde_json::from_value(send_state)
162                            .map_err(|e| format!("Invalid send_state: {}", e))?;
163                        let result = handler(params, send_state, state).await?;
164                        // If there's a confirm handler, mark this as a preview
165                        serde_json::to_string(&result)
166                            .map_err(|e| format!("Failed to serialize result: {}", e))
167                    })
168                },
169            ),
170            confirm_handler: self.confirm_handler.map(|handler| {
171                Arc::new(Box::new(
172                    move |params: serde_json::Value,
173                          send_state: serde_json::Value,
174                          state: AgentState<S>| {
175                        let handler = handler.clone();
176                        let fut: Pin<
177                            Box<dyn Future<Output = Result<String, String>> + Send + Sync>,
178                        > = Box::pin(async move {
179                            let params = serde_json::from_value::<Q>(params)
180                                .map_err(|e| format!("Invalid parameters: {}", e))?;
181                            let send_state = serde_json::from_value(send_state)
182                                .map_err(|e| format!("Invalid send_state: {}", e))?;
183                            handler(params, send_state, state).await
184                        });
185                        fut
186                    },
187                ) as Box<dyn Fn(_, _, _) -> _ + Send + Sync>)
188            }),
189        }
190    }
191}
192
193/// Represents a group of related actions
194pub trait ActionGroup<S: Send + Sync + Clone + 'static> {
195    fn actions(&self) -> &[Arc<FunctionAction<S>>];
196}
197
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::Mutex;

    use super::*;

    // Strongly typed handler input, deserialized from the raw `params` JSON.
    #[derive(Debug, Deserialize)]
    struct WeatherParams {
        location: String,
        #[serde(default)]
        units: Option<String>,
    }

    #[tokio::test]
    async fn test_typed_function_action() {
        async fn weather(
            params: WeatherParams,
            _send_state: serde_json::Value,
            _state: AgentState<()>,
        ) -> Result<String, String> {
            println!("Executing action in function: {:?}", "get_weather");
            let units = params.units.unwrap_or_else(|| "celsius".to_string());
            Ok(format!("Weather in {} ({}): Sunny", params.location, units))
        }

        let action = ActionBuilder::<_, _, _, _>::new("get_weather", weather, None)
            .description("Get the weather for a location")
            .parameter("location", "The city to get weather for", "string", true)
            .parameter(
                "units",
                "Temperature units (celsius/fahrenheit)",
                "string",
                false,
            )
            .build();

        // Test the definition
        let def = action.definition();
        assert_eq!(def.name, "get_weather");
        assert_eq!(def.parameters.len(), 2);

        // Test execution with all parameters
        let state = Arc::new(Mutex::new(()));
        let send_state = serde_json::json!({});
        let params = serde_json::json!({
            "location": "London",
            "units": "fahrenheit"
        });
        let result = action
            .execute(params, send_state.clone(), state.clone())
            .await
            .unwrap();

        // The result is JSON-serialized, so compare it against the JSON encoding of the
        // expected string.
        let expected = serde_json::to_string(&"Weather in London (fahrenheit): Sunny").unwrap();
        assert_eq!(result, expected);

        // Test execution with only required parameters
        let params = serde_json::json!({
            "location": "Paris"
        });
        let result = action
            .execute(params, send_state.clone(), state.clone())
            .await
            .unwrap();
        let expected = serde_json::to_string(&"Weather in Paris (celsius): Sunny").unwrap();
        assert_eq!(result, expected);

        // Without a confirm handler there is nothing to confirm.
        assert!(action
            .confirm(serde_json::json!({}), send_state.clone(), state.clone())
            .is_none());

        // Test execution with invalid parameters
        let params = serde_json::json!({
            "wrong_field": "London"
        });
        let result = action.execute(params, send_state, state.clone()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_confirm_handler_receives_handler_output() {
        async fn preview(
            params: WeatherParams,
            _send_state: serde_json::Value,
            _state: AgentState<()>,
        ) -> Result<String, String> {
            Ok(format!("Fetch weather for {}?", params.location))
        }

        async fn confirm(
            summary: String,
            _send_state: serde_json::Value,
            _state: AgentState<()>,
        ) -> Result<String, String> {
            Ok(format!("Confirmed: {}", summary))
        }

        let action =
            ActionBuilder::<_, _, _, _, _, _>::new("get_weather", preview, Some(confirm)).build();

        let state = Arc::new(Mutex::new(()));
        let send_state = serde_json::json!({});

        let preview_text = action
            .execute(
                serde_json::json!({ "location": "Berlin" }),
                send_state.clone(),
                state.clone(),
            )
            .await
            .unwrap();
        assert_eq!(
            preview_text,
            serde_json::to_string(&"Fetch weather for Berlin?").unwrap()
        );

        // Feed the serialized handler output back in as the confirm handler's params.
        let preview_value: serde_json::Value = serde_json::from_str(&preview_text).unwrap();
        let confirmed = action
            .confirm(preview_value, send_state.clone(), state.clone())
            .expect("a confirm handler was registered")
            .await
            .unwrap();
        // The confirm handler's String is returned as-is, not JSON-encoded again.
        assert_eq!(confirmed, "Confirmed: Fetch weather for Berlin?");

        // Confirm params that do not deserialize into Q (String) are rejected.
        let rejected = action
            .confirm(serde_json::json!({ "not": "a string" }), send_state, state)
            .expect("a confirm handler was registered")
            .await;
        assert!(rejected.is_err());
    }
}