1use serde::{Deserialize, Serialize};
41use std::collections::HashMap;
42
43use super::enums::{MetaToolSlug, TagType};
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct SessionConfig {
84 pub user_id: String,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 pub toolkits: Option<ToolkitFilter>,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 pub auth_configs: Option<HashMap<String, String>>,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 pub connected_accounts: Option<HashMap<String, String>>,
91 #[serde(skip_serializing_if = "Option::is_none")]
92 pub manage_connections: Option<ManageConnectionsConfig>,
93 #[serde(skip_serializing_if = "Option::is_none")]
94 pub tools: Option<ToolsConfig>,
95 #[serde(skip_serializing_if = "Option::is_none")]
96 pub tags: Option<TagsConfig>,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub workbench: Option<WorkbenchConfig>,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
126#[serde(untagged)]
127pub enum ManageConnectionsConfig {
128 Bool(bool),
130 Detailed {
132 enabled: bool,
133 #[serde(skip_serializing_if = "Option::is_none")]
134 enable_wait_for_connections: Option<bool>,
135 },
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
162#[serde(untagged)]
163pub enum ToolkitFilter {
164 Enable(Vec<String>),
165 Disable { disable: Vec<String> },
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct ToolsConfig(pub HashMap<String, ToolFilter>);
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
175#[serde(untagged)]
176pub enum ToolFilter {
177 Enable { enable: Vec<String> },
179 Disable { disable: Vec<String> },
181 EnableList(Vec<String>),
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct TagsConfig {
189 #[serde(skip_serializing_if = "Option::is_none")]
191 pub enabled: Option<Vec<TagType>>,
192 #[serde(skip_serializing_if = "Option::is_none")]
194 pub disabled: Option<Vec<TagType>>,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct WorkbenchConfig {
200 #[serde(skip_serializing_if = "Option::is_none")]
201 #[serde(alias = "proxy_execution_enabled")]
202 pub proxy_execution: Option<bool>,
203 #[serde(skip_serializing_if = "Option::is_none")]
204 pub auto_offload_threshold: Option<u32>,
205}
206
207#[derive(Debug, Clone, Serialize)]
209pub struct ToolExecutionRequest {
210 pub tool_slug: String,
211 #[serde(skip_serializing_if = "Option::is_none")]
212 pub arguments: Option<serde_json::Value>,
213}
214
215#[derive(Debug, Clone, Serialize)]
217pub struct MetaToolExecutionRequest {
218 pub slug: MetaToolSlug,
219 #[serde(skip_serializing_if = "Option::is_none")]
220 pub arguments: Option<serde_json::Value>,
221}
222
223#[derive(Debug, Clone, Serialize)]
225pub struct LinkRequest {
226 pub toolkit: String,
227 #[serde(skip_serializing_if = "Option::is_none")]
228 pub callback_url: Option<String>,
229}
230
#[cfg(test)]
mod tests {
    // `HashMap`, `Serialize` and the types under test come in through
    // `super::*`; `serde_json` is reachable via the extern prelude, so it
    // needs no `use` of its own.
    use super::*;

    /// Returns a `SessionConfig` for `user_id` with every optional field unset,
    /// so each test only spells out the field it exercises.
    fn base_config(user_id: &str) -> SessionConfig {
        SessionConfig {
            user_id: user_id.to_string(),
            toolkits: None,
            auth_configs: None,
            connected_accounts: None,
            manage_connections: None,
            tools: None,
            tags: None,
            workbench: None,
        }
    }

    /// Serializes `value` to JSON text and parses it back, so assertions
    /// inspect exactly what would be sent over the wire.
    fn to_json<T: Serialize>(value: &T) -> serde_json::Value {
        let text = serde_json::to_string(value).expect("value should serialize");
        serde_json::from_str(&text).expect("serialized JSON should parse")
    }

    // ---- Serialization ----------------------------------------------------

    #[test]
    fn test_session_config_minimal_serialization() {
        let config = base_config("user_123");

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("user_123"));
        assert!(!json.contains("toolkits"));
        assert!(!json.contains("auth_configs"));
        // Every `None` field must be skipped, leaving only `user_id`.
        assert_eq!(to_json(&config), serde_json::json!({ "user_id": "user_123" }));
    }

    #[test]
    fn test_session_config_with_toolkits_enable() {
        let config = SessionConfig {
            toolkits: Some(ToolkitFilter::Enable(vec![
                "github".to_string(),
                "gmail".to_string(),
            ])),
            ..base_config("user_123")
        };

        let parsed = to_json(&config);

        assert!(parsed["toolkits"].is_array());
        let toolkits = parsed["toolkits"].as_array().unwrap();
        assert_eq!(toolkits.len(), 2);
    }

    #[test]
    fn test_session_config_with_toolkits_disable() {
        let config = SessionConfig {
            toolkits: Some(ToolkitFilter::Disable {
                disable: vec!["exa".to_string(), "firecrawl".to_string()],
            }),
            ..base_config("user_123")
        };

        let parsed = to_json(&config);

        assert!(parsed["toolkits"].is_object());
        assert!(parsed["toolkits"]["disable"].is_array());
    }

    #[test]
    fn test_session_config_with_auth_configs() {
        let mut auth_configs = HashMap::new();
        auth_configs.insert("github".to_string(), "ac_custom".to_string());

        let config = SessionConfig {
            auth_configs: Some(auth_configs),
            ..base_config("user_123")
        };

        let parsed = to_json(&config);

        assert_eq!(parsed["auth_configs"]["github"], "ac_custom");
    }

    #[test]
    fn test_session_config_with_manage_connections_bool() {
        let config = SessionConfig {
            manage_connections: Some(ManageConnectionsConfig::Bool(true)),
            ..base_config("user_123")
        };

        let parsed = to_json(&config);

        assert_eq!(parsed["manage_connections"], true);
    }

    #[test]
    fn test_session_config_with_manage_connections_detailed() {
        let config = SessionConfig {
            manage_connections: Some(ManageConnectionsConfig::Detailed {
                enabled: true,
                enable_wait_for_connections: Some(false),
            }),
            ..base_config("user_123")
        };

        let parsed = to_json(&config);

        assert_eq!(parsed["manage_connections"]["enabled"], true);
        assert_eq!(parsed["manage_connections"]["enable_wait_for_connections"], false);
    }

    #[test]
    fn test_session_config_with_tools() {
        let mut tools_map = HashMap::new();
        tools_map.insert(
            "github".to_string(),
            ToolFilter::EnableList(vec!["GITHUB_CREATE_ISSUE".to_string()]),
        );

        let config = SessionConfig {
            tools: Some(ToolsConfig(tools_map)),
            ..base_config("user_123")
        };

        let parsed = to_json(&config);

        assert!(parsed["tools"]["github"].is_array());
    }

    #[test]
    fn test_session_config_with_tags() {
        let config = SessionConfig {
            tags: Some(TagsConfig {
                enabled: Some(vec![TagType::ReadOnlyHint]),
                disabled: Some(vec![TagType::DestructiveHint]),
            }),
            ..base_config("user_123")
        };

        let parsed = to_json(&config);

        assert!(parsed["tags"]["enabled"].is_array());
        assert!(parsed["tags"]["disabled"].is_array());
    }

    #[test]
    fn test_session_config_with_workbench() {
        let config = SessionConfig {
            workbench: Some(WorkbenchConfig {
                proxy_execution: Some(true),
                auto_offload_threshold: Some(1000),
            }),
            ..base_config("user_123")
        };

        let parsed = to_json(&config);

        assert_eq!(parsed["workbench"]["proxy_execution"], true);
        assert_eq!(parsed["workbench"]["auto_offload_threshold"], 1000);
    }

    #[test]
    fn test_toolkit_filter_enable_serialization() {
        let filter = ToolkitFilter::Enable(vec!["github".to_string(), "gmail".to_string()]);
        let parsed = to_json(&filter);

        assert!(parsed.is_array());
        assert_eq!(parsed.as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_toolkit_filter_disable_serialization() {
        let filter = ToolkitFilter::Disable {
            disable: vec!["exa".to_string()],
        };
        let parsed = to_json(&filter);

        assert!(parsed.is_object());
        assert!(parsed["disable"].is_array());
    }

    #[test]
    fn test_tool_filter_enable_serialization() {
        let filter = ToolFilter::Enable {
            enable: vec!["GITHUB_CREATE_ISSUE".to_string()],
        };
        let parsed = to_json(&filter);

        assert!(parsed.is_object());
        assert!(parsed["enable"].is_array());
    }

    #[test]
    fn test_tool_filter_disable_serialization() {
        let filter = ToolFilter::Disable {
            disable: vec!["GITHUB_DELETE_REPO".to_string()],
        };
        let parsed = to_json(&filter);

        assert!(parsed.is_object());
        assert!(parsed["disable"].is_array());
    }

    #[test]
    fn test_tool_filter_enable_list_serialization() {
        let filter = ToolFilter::EnableList(vec!["GITHUB_CREATE_ISSUE".to_string()]);
        let parsed = to_json(&filter);

        assert!(parsed.is_array());
    }

    #[test]
    fn test_tool_execution_request_serialization() {
        let request = ToolExecutionRequest {
            tool_slug: "GITHUB_CREATE_ISSUE".to_string(),
            arguments: Some(serde_json::json!({
                "owner": "composio",
                "repo": "composio",
                "title": "Test issue"
            })),
        };

        let parsed = to_json(&request);

        assert_eq!(parsed["tool_slug"], "GITHUB_CREATE_ISSUE");
        assert!(parsed["arguments"].is_object());
        assert_eq!(parsed["arguments"]["owner"], "composio");
    }

    #[test]
    fn test_tool_execution_request_without_arguments() {
        let request = ToolExecutionRequest {
            tool_slug: "GITHUB_GET_USER".to_string(),
            arguments: None,
        };

        let parsed = to_json(&request);

        assert_eq!(parsed["tool_slug"], "GITHUB_GET_USER");
        assert!(parsed.get("arguments").is_none());
    }

    #[test]
    fn test_meta_tool_execution_request_serialization() {
        let request = MetaToolExecutionRequest {
            slug: MetaToolSlug::ComposioSearchTools,
            arguments: Some(serde_json::json!({
                "query": "create a GitHub issue"
            })),
        };

        let parsed = to_json(&request);

        assert_eq!(parsed["slug"], "COMPOSIO_SEARCH_TOOLS");
        assert!(parsed["arguments"].is_object());
    }

    #[test]
    fn test_link_request_serialization() {
        let request = LinkRequest {
            toolkit: "github".to_string(),
            callback_url: Some("https://example.com/callback".to_string()),
        };

        let parsed = to_json(&request);

        assert_eq!(parsed["toolkit"], "github");
        assert_eq!(parsed["callback_url"], "https://example.com/callback");
    }

    #[test]
    fn test_link_request_without_callback() {
        let request = LinkRequest {
            toolkit: "gmail".to_string(),
            callback_url: None,
        };

        let parsed = to_json(&request);

        assert_eq!(parsed["toolkit"], "gmail");
        assert!(parsed.get("callback_url").is_none());
    }

    #[test]
    fn test_tags_config_serialization() {
        let config = TagsConfig {
            enabled: Some(vec![TagType::ReadOnlyHint, TagType::IdempotentHint]),
            disabled: Some(vec![TagType::DestructiveHint]),
        };

        let parsed = to_json(&config);

        assert!(parsed["enabled"].is_array());
        assert!(parsed["disabled"].is_array());
        assert_eq!(parsed["enabled"].as_array().unwrap().len(), 2);
        assert_eq!(parsed["disabled"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn test_workbench_config_serialization() {
        let config = WorkbenchConfig {
            proxy_execution: Some(true),
            auto_offload_threshold: Some(500),
        };

        let parsed = to_json(&config);

        assert_eq!(parsed["proxy_execution"], true);
        assert_eq!(parsed["auto_offload_threshold"], 500);
    }

    #[test]
    fn test_workbench_config_partial_serialization() {
        let config = WorkbenchConfig {
            proxy_execution: Some(false),
            auto_offload_threshold: None,
        };

        let parsed = to_json(&config);

        assert_eq!(parsed["proxy_execution"], false);
        assert!(parsed.get("auto_offload_threshold").is_none());
    }

    // ---- Deserialization --------------------------------------------------
    // These pin how the untagged enums pick a variant from the JSON shape,
    // which the serialization tests above cannot observe.

    #[test]
    fn test_session_config_minimal_deserialization() {
        let config: SessionConfig = serde_json::from_str(r#"{"user_id":"user_123"}"#).unwrap();

        assert_eq!(config.user_id, "user_123");
        assert!(config.toolkits.is_none());
        assert!(config.auth_configs.is_none());
        assert!(config.connected_accounts.is_none());
        assert!(config.manage_connections.is_none());
        assert!(config.tools.is_none());
        assert!(config.tags.is_none());
        assert!(config.workbench.is_none());
    }

    #[test]
    fn test_toolkit_filter_deserialization_picks_variant_by_shape() {
        match serde_json::from_str::<ToolkitFilter>(r#"["github","gmail"]"#).unwrap() {
            ToolkitFilter::Enable(list) => assert_eq!(list, ["github", "gmail"]),
            other => panic!("expected Enable, got {:?}", other),
        }

        match serde_json::from_str::<ToolkitFilter>(r#"{"disable":["exa"]}"#).unwrap() {
            ToolkitFilter::Disable { disable } => assert_eq!(disable, ["exa"]),
            other => panic!("expected Disable, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_filter_deserialization_picks_variant_by_shape() {
        match serde_json::from_str::<ToolFilter>(r#"["GITHUB_CREATE_ISSUE"]"#).unwrap() {
            ToolFilter::EnableList(list) => assert_eq!(list, ["GITHUB_CREATE_ISSUE"]),
            other => panic!("expected EnableList, got {:?}", other),
        }

        match serde_json::from_str::<ToolFilter>(r#"{"enable":["GITHUB_CREATE_ISSUE"]}"#)
            .unwrap()
        {
            ToolFilter::Enable { enable } => assert_eq!(enable, ["GITHUB_CREATE_ISSUE"]),
            other => panic!("expected Enable, got {:?}", other),
        }

        match serde_json::from_str::<ToolFilter>(r#"{"disable":["GITHUB_DELETE_REPO"]}"#)
            .unwrap()
        {
            ToolFilter::Disable { disable } => assert_eq!(disable, ["GITHUB_DELETE_REPO"]),
            other => panic!("expected Disable, got {:?}", other),
        }
    }

    #[test]
    fn test_manage_connections_deserialization_picks_variant_by_shape() {
        match serde_json::from_str::<ManageConnectionsConfig>("true").unwrap() {
            ManageConnectionsConfig::Bool(value) => assert!(value),
            other => panic!("expected Bool, got {:?}", other),
        }

        match serde_json::from_str::<ManageConnectionsConfig>(r#"{"enabled":false}"#).unwrap() {
            ManageConnectionsConfig::Detailed {
                enabled,
                enable_wait_for_connections,
            } => {
                assert!(!enabled);
                assert_eq!(enable_wait_for_connections, None);
            }
            other => panic!("expected Detailed, got {:?}", other),
        }
    }

    #[test]
    fn test_session_config_tools_deserialization() {
        let config: SessionConfig = serde_json::from_str(
            r#"{"user_id":"user_123","tools":{"github":["GITHUB_CREATE_ISSUE"]}}"#,
        )
        .unwrap();

        let tools = config.tools.expect("tools should be present");
        assert!(matches!(
            tools.0.get("github"),
            Some(ToolFilter::EnableList(list)) if list.len() == 1
        ));
    }

    #[test]
    fn test_workbench_config_accepts_alias_but_serializes_primary_name() {
        let config: WorkbenchConfig =
            serde_json::from_str(r#"{"proxy_execution_enabled":true}"#).unwrap();

        assert_eq!(config.proxy_execution, Some(true));
        assert_eq!(config.auto_offload_threshold, None);

        let parsed = to_json(&config);
        assert_eq!(parsed["proxy_execution"], true);
        assert!(parsed.get("proxy_execution_enabled").is_none());
    }
}