1use crate::client::ComposioClient;
77use crate::error::ComposioError;
78use crate::models::response::ToolProxyResponse;
79use crate::models::tools::{ProxyParameter, ToolInfo, ToolkitRef};
80use async_trait::async_trait;
81use serde_json::Value as JsonValue;
82use std::collections::HashMap;
83use std::sync::Arc;
84
85#[async_trait]
87pub trait ExecuteRequestFn: Send + Sync {
88 async fn execute(
97 &self,
98 endpoint: &str,
99 method: &str,
100 body: Option<JsonValue>,
101 connected_account_id: Option<&str>,
102 parameters: Option<Vec<ProxyParameter>>,
103 ) -> Result<ToolProxyResponse, ComposioError>;
104}
105
106#[async_trait]
112pub trait CustomToolExecutor: Send + Sync {
113 async fn execute(
120 &self,
121 request: JsonValue,
122 execute_request: Option<&dyn ExecuteRequestFn>,
123 auth_credentials: Option<&HashMap<String, JsonValue>>,
124 ) -> Result<JsonValue, ComposioError>;
125}
126
127struct SimpleExecutor<F>
129where
130 F: Fn(JsonValue) -> Result<JsonValue, ComposioError> + Send + Sync,
131{
132 func: F,
133}
134
135#[async_trait]
136impl<F> CustomToolExecutor for SimpleExecutor<F>
137where
138 F: Fn(JsonValue) -> Result<JsonValue, ComposioError> + Send + Sync,
139{
140 async fn execute(
141 &self,
142 request: JsonValue,
143 _execute_request: Option<&dyn ExecuteRequestFn>,
144 _auth_credentials: Option<&HashMap<String, JsonValue>>,
145 ) -> Result<JsonValue, ComposioError> {
146 (self.func)(request)
147 }
148}
149
150struct AuthenticatedExecutor<F>
152where
153 F: Fn(JsonValue, &dyn ExecuteRequestFn, &HashMap<String, JsonValue>) -> Result<JsonValue, ComposioError>
154 + Send
155 + Sync,
156{
157 func: F,
158}
159
160#[async_trait]
161impl<F> CustomToolExecutor for AuthenticatedExecutor<F>
162where
163 F: Fn(JsonValue, &dyn ExecuteRequestFn, &HashMap<String, JsonValue>) -> Result<JsonValue, ComposioError>
164 + Send
165 + Sync,
166{
167 async fn execute(
168 &self,
169 request: JsonValue,
170 execute_request: Option<&dyn ExecuteRequestFn>,
171 auth_credentials: Option<&HashMap<String, JsonValue>>,
172 ) -> Result<JsonValue, ComposioError> {
173 let execute_request = execute_request
174 .ok_or_else(|| ComposioError::InvalidInput("Execute request function required".to_string()))?;
175 let auth_credentials = auth_credentials
176 .ok_or_else(|| ComposioError::InvalidInput("Auth credentials required".to_string()))?;
177
178 (self.func)(request, execute_request, auth_credentials)
179 }
180}
181
182pub struct CustomTool {
184 pub slug: String,
186
187 pub name: String,
189
190 pub description: String,
192
193 pub toolkit: Option<String>,
195
196 pub input_schema: JsonValue,
198
199 pub output_schema: Option<JsonValue>,
201
202 pub requires_auth: bool,
204
205 executor: Box<dyn CustomToolExecutor>,
207
208 client: Arc<ComposioClient>,
210}
211
212impl CustomTool {
213 pub fn new_simple<F>(
222 name: &str,
223 description: &str,
224 input_schema: JsonValue,
225 executor: F,
226 client: Arc<ComposioClient>,
227 ) -> Self
228 where
229 F: Fn(JsonValue) -> Result<JsonValue, ComposioError> + Send + Sync + 'static,
230 {
231 let slug = name.to_uppercase().replace(' ', "_");
232
233 Self {
234 slug,
235 name: name.to_string(),
236 description: description.to_string(),
237 toolkit: None,
238 input_schema,
239 output_schema: None,
240 requires_auth: false,
241 executor: Box::new(SimpleExecutor { func: executor }),
242 client,
243 }
244 }
245
246 pub fn new_with_auth<F>(
256 name: &str,
257 description: &str,
258 toolkit: &str,
259 input_schema: JsonValue,
260 executor: F,
261 client: Arc<ComposioClient>,
262 ) -> Self
263 where
264 F: Fn(JsonValue, &dyn ExecuteRequestFn, &HashMap<String, JsonValue>) -> Result<JsonValue, ComposioError>
265 + Send
266 + Sync
267 + 'static,
268 {
269 let toolkit_upper = toolkit.to_uppercase();
270 let name_upper = name.to_uppercase().replace(' ', "_");
271 let slug = format!("{}_{}", toolkit_upper, name_upper);
272 let full_name = format!("{}_{}", toolkit.to_lowercase(), name);
273
274 Self {
275 slug,
276 name: full_name,
277 description: description.to_string(),
278 toolkit: Some(toolkit.to_string()),
279 input_schema,
280 output_schema: None,
281 requires_auth: true,
282 executor: Box::new(AuthenticatedExecutor { func: executor }),
283 client,
284 }
285 }
286
287 pub async fn execute(
293 &self,
294 arguments: HashMap<String, JsonValue>,
295 user_id: Option<&str>,
296 ) -> Result<JsonValue, ComposioError> {
297 let request = JsonValue::Object(
298 arguments.into_iter()
299 .map(|(k, v)| (k, v))
300 .collect()
301 );
302
303 if self.requires_auth {
304 let user_id = user_id.ok_or_else(|| {
305 ComposioError::InvalidInput("user_id required for authenticated tools".to_string())
306 })?;
307
308 let auth_credentials = self.get_auth_credentials(user_id).await?;
309
310 let proxy_executor = ProxyExecutor {
312 client: self.client.clone(),
313 toolkit: self.toolkit.clone().unwrap(),
314 };
315
316 self.executor.execute(
317 request,
318 Some(&proxy_executor),
319 Some(&auth_credentials),
320 ).await
321 } else {
322 self.executor.execute(request, None, None).await
323 }
324 }
325
326 async fn get_auth_credentials(&self, user_id: &str) -> Result<HashMap<String, JsonValue>, ComposioError> {
328 let toolkit = self.toolkit.as_ref()
329 .ok_or_else(|| ComposioError::InvalidInput("Toolkit required for auth".to_string()))?;
330
331 let params = crate::models::connected_accounts::ConnectedAccountListParams {
333 user_ids: Some(vec![user_id.to_string()]),
334 toolkit_slugs: Some(vec![toolkit.clone()]),
335 statuses: Some(vec![crate::models::connected_accounts::ConnectionStatus::Active]),
336 ..Default::default()
337 };
338
339 let accounts = self.client.list_connected_accounts(params).await?;
340
341 if accounts.items.is_empty() {
342 return Err(ComposioError::ValidationError(format!(
343 "No active connected accounts found for toolkit {} and user {}",
344 toolkit, user_id
345 )));
346 }
347
348 let account = accounts.items.into_iter()
350 .max_by(|a, b| a.created_at.cmp(&b.created_at))
351 .unwrap();
352
353 if let Some(state) = account.state {
355 Ok(serde_json::from_value(state)?)
356 } else {
357 Err(ComposioError::ValidationError(
358 "Connected account has no state data".to_string()
359 ))
360 }
361 }
362
363 pub fn to_tool_info(&self) -> ToolInfo {
365 ToolInfo {
366 slug: self.slug.clone(),
367 name: self.name.clone(),
368 description: self.description.clone(),
369 input_parameters: self.input_schema.clone(),
370 output_parameters: self.output_schema.clone().unwrap_or(JsonValue::Object(Default::default())),
371 scopes: vec![],
372 version: "1.0.0".to_string(),
373 available_versions: vec![],
374 toolkit: ToolkitRef {
375 slug: self.toolkit.clone().unwrap_or_else(|| "custom".to_string()).to_uppercase(),
376 name: Some(self.toolkit.clone().unwrap_or_else(|| "custom".to_string())),
377 logo: None,
378 },
379 is_deprecated: false,
380 no_auth: !self.requires_auth,
381 tags: vec![],
382 }
383 }
384}
385
386struct ProxyExecutor {
388 #[allow(dead_code)]
389 client: Arc<ComposioClient>,
390 #[allow(dead_code)]
391 toolkit: String,
392}
393
394#[async_trait]
395impl ExecuteRequestFn for ProxyExecutor {
396 async fn execute(
397 &self,
398 _endpoint: &str,
399 _method: &str,
400 _body: Option<JsonValue>,
401 _connected_account_id: Option<&str>,
402 _parameters: Option<Vec<ProxyParameter>>,
403 ) -> Result<ToolProxyResponse, ComposioError> {
404 Err(ComposioError::InvalidInput(
407 "Proxy execution not yet fully implemented - requires proxy API endpoint".to_string()
408 ))
409 }
410}
411
412pub struct CustomToolsRegistry {
414 tools: HashMap<String, Arc<CustomTool>>,
415 client: Arc<ComposioClient>,
416}
417
418impl CustomToolsRegistry {
419 pub fn new(client: Arc<ComposioClient>) -> Self {
421 Self {
422 tools: HashMap::new(),
423 client,
424 }
425 }
426
427 pub fn register_simple<F>(
438 &mut self,
439 name: &str,
440 description: &str,
441 input_schema: JsonValue,
442 executor: F,
443 ) -> Arc<CustomTool>
444 where
445 F: Fn(JsonValue) -> Result<JsonValue, ComposioError> + Send + Sync + 'static,
446 {
447 let tool = Arc::new(CustomTool::new_simple(
448 name,
449 description,
450 input_schema,
451 executor,
452 self.client.clone(),
453 ));
454
455 self.tools.insert(tool.slug.clone(), tool.clone());
456 tool
457 }
458
459 pub fn register_with_auth<F>(
471 &mut self,
472 name: &str,
473 description: &str,
474 toolkit: &str,
475 input_schema: JsonValue,
476 executor: F,
477 ) -> Arc<CustomTool>
478 where
479 F: Fn(JsonValue, &dyn ExecuteRequestFn, &HashMap<String, JsonValue>) -> Result<JsonValue, ComposioError>
480 + Send
481 + Sync
482 + 'static,
483 {
484 let tool = Arc::new(CustomTool::new_with_auth(
485 name,
486 description,
487 toolkit,
488 input_schema,
489 executor,
490 self.client.clone(),
491 ));
492
493 self.tools.insert(tool.slug.clone(), tool.clone());
494 tool
495 }
496
497 pub fn get(&self, slug: &str) -> Option<Arc<CustomTool>> {
499 self.tools.get(slug).cloned()
500 }
501
502 pub async fn execute(
509 &self,
510 slug: &str,
511 arguments: HashMap<String, JsonValue>,
512 user_id: Option<&str>,
513 ) -> Result<JsonValue, ComposioError> {
514 let tool = self.get(slug)
515 .ok_or_else(|| ComposioError::ValidationError(format!("Custom tool {} not found", slug)))?;
516
517 tool.execute(arguments, user_id).await
518 }
519
520 pub fn list(&self) -> Vec<Arc<CustomTool>> {
522 self.tools.values().cloned().collect()
523 }
524
525 pub fn list_as_tools(&self) -> Vec<ToolInfo> {
527 self.tools.values()
528 .map(|tool| tool.to_tool_info())
529 .collect()
530 }
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a shared client configured with a placeholder API key.
    fn test_client() -> Arc<ComposioClient> {
        let built = ComposioClient::builder()
            .api_key("test_key")
            .build()
            .unwrap();
        Arc::new(built)
    }

    #[test]
    fn test_custom_tool_simple() {
        let schema = json!({
            "type": "object",
            "properties": {
                "a": {"type": "number"},
                "b": {"type": "number"}
            }
        });
        let adder = |request: JsonValue| {
            let lhs = request["a"].as_f64().unwrap_or(0.0);
            let rhs = request["b"].as_f64().unwrap_or(0.0);
            Ok(json!({"result": lhs + rhs}))
        };

        let tool = CustomTool::new_simple(
            "calculate sum",
            "Calculate the sum of two numbers",
            schema,
            adder,
            test_client(),
        );

        assert_eq!(tool.slug, "CALCULATE_SUM");
        assert_eq!(tool.name, "calculate sum");
        assert!(!tool.requires_auth);
        assert!(tool.toolkit.is_none());
    }

    #[test]
    fn test_custom_tool_with_auth() {
        let schema = json!({
            "type": "object",
            "properties": {
                "title": {"type": "string"}
            }
        });

        let tool = CustomTool::new_with_auth(
            "create issue",
            "Create a GitHub issue",
            "github",
            schema,
            |_request, _execute_request, _auth_credentials| Ok(json!({"id": 123})),
            test_client(),
        );

        assert_eq!(tool.slug, "GITHUB_CREATE_ISSUE");
        assert_eq!(tool.name, "github_create issue");
        assert!(tool.requires_auth);
        assert_eq!(tool.toolkit, Some("github".to_string()));
    }

    #[test]
    fn test_registry() {
        let mut registry = CustomToolsRegistry::new(test_client());

        registry.register_simple(
            "test_tool",
            "A test tool",
            json!({"type": "object"}),
            |_request| Ok(json!({"success": true})),
        );

        assert!(registry.get("TEST_TOOL").is_some());
        assert_eq!(registry.list().len(), 1);
    }
}