use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::enums::{MetaToolSlug, TagType};
/// Configuration payload for a session tied to a single user.
///
/// `user_id` is the only mandatory field. Every other field is optional and,
/// when `None`, is left out of the serialized output entirely instead of
/// being written as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// User identifier; always present in the serialized output.
    pub user_id: String,

    /// Toolkit filter; see [`ToolkitFilter`] for the two accepted shapes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub toolkits: Option<ToolkitFilter>,

    /// String-to-string map, serialized as a JSON object.
    ///
    /// NOTE(review): presumably keyed by toolkit slug with an auth-config id
    /// as the value — confirm against the API contract.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth_configs: Option<HashMap<String, String>>,

    /// String-to-string map, serialized as a JSON object.
    ///
    /// NOTE(review): presumably keyed by toolkit slug with a connected-account
    /// id as the value — confirm against the API contract.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub connected_accounts: Option<HashMap<String, String>>,

    /// Connection-management setting; either a bare boolean or a detailed
    /// object (see [`ManageConnectionsConfig`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub manage_connections: Option<ManageConnectionsConfig>,

    /// Per-key tool filters; see [`ToolsConfig`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<ToolsConfig>,

    /// Tag-based filtering; see [`TagsConfig`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<TagsConfig>,

    /// Workbench settings; see [`WorkbenchConfig`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub workbench: Option<WorkbenchConfig>,
}
/// Connection-management setting, expressed either as a bare boolean or as a
/// detailed object.
///
/// The enum is `untagged`, so no variant name appears in the output:
/// `Bool` is written as a plain JSON `true`/`false`, and `Detailed` as an
/// object containing its fields. On deserialization serde tries the variants
/// in declaration order and keeps the first one that matches.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ManageConnectionsConfig {
    /// Shorthand form: a single boolean.
    Bool(bool),

    /// Object form.
    Detailed {
        /// Always written to the output.
        enabled: bool,

        /// Omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        enable_wait_for_connections: Option<bool>,
    },
}
/// Toolkit filter with two wire shapes.
///
/// Because the enum is `untagged`, the variant is identified purely by the
/// JSON shape:
///
/// * `Enable` — a bare array of strings, e.g. `["github", "gmail"]`
/// * `Disable` — an object with a single `disable` array,
///   e.g. `{"disable": ["exa"]}`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolkitFilter {
    /// Serialized as a plain JSON array.
    Enable(Vec<String>),

    /// Serialized as `{"disable": [...]}`.
    Disable {
        /// Entries listed under the `disable` key.
        disable: Vec<String>,
    },
}
/// Map from a string key to the [`ToolFilter`] applied to it.
///
/// This is a newtype struct, which serde serializes as the wrapped value
/// itself, so the output is a plain JSON object with no extra wrapper.
///
/// NOTE(review): keys are presumably toolkit slugs (the tests in this file
/// use `"github"`) — confirm against the API contract.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolsConfig(
    /// The underlying key-to-filter map.
    pub HashMap<String, ToolFilter>,
);
/// Tool filter with three wire shapes.
///
/// The enum is `untagged`, so the variant is identified by JSON shape alone:
///
/// * `Enable` — `{"enable": [...]}`
/// * `Disable` — `{"disable": [...]}`
/// * `EnableList` — a bare array of strings
///
/// When deserializing, serde tries the variants in the order declared here
/// and keeps the first match, so the declaration order is significant and
/// must not be rearranged casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolFilter {
    /// Object form carrying an `enable` array.
    Enable {
        /// Entries listed under the `enable` key.
        enable: Vec<String>,
    },

    /// Object form carrying a `disable` array.
    Disable {
        /// Entries listed under the `disable` key.
        disable: Vec<String>,
    },

    /// Bare-array form.
    ///
    /// NOTE(review): judging by the name this is shorthand for `Enable` —
    /// confirm against the API contract.
    EnableList(Vec<String>),
}
/// Tag-based configuration made of two optional lists of [`TagType`].
///
/// Each list is left out of the serialized output when it is `None`; when
/// present it is written as a JSON array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TagsConfig {
    /// Tags written under the `enabled` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<Vec<TagType>>,

    /// Tags written under the `disabled` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub disabled: Option<Vec<TagType>>,
}
/// Workbench settings; both fields are optional and are omitted from the
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkbenchConfig {
    /// Written as `proxy_execution`.
    ///
    /// When deserializing, the key `proxy_execution_enabled` is accepted as an
    /// alternative spelling; the alias has no effect on serialization.
    #[serde(skip_serializing_if = "Option::is_none", alias = "proxy_execution_enabled")]
    pub proxy_execution: Option<bool>,

    /// Written as `auto_offload_threshold`.
    ///
    /// NOTE(review): the unit of this threshold is not visible in this file —
    /// confirm against the API contract.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_offload_threshold: Option<u32>,
}
/// Request body for executing a tool identified by a string slug.
///
/// This type is serialize-only: it derives `Serialize` but not `Deserialize`.
#[derive(Debug, Clone, Serialize)]
pub struct ToolExecutionRequest {
    /// Slug of the tool to execute; always written.
    pub tool_slug: String,

    /// Arbitrary JSON arguments; the `arguments` key is omitted entirely when
    /// this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<serde_json::Value>,
}
/// Request body for executing a meta tool identified by a [`MetaToolSlug`].
///
/// This type is serialize-only: it derives `Serialize` but not `Deserialize`.
#[derive(Debug, Clone, Serialize)]
pub struct MetaToolExecutionRequest {
    /// Which meta tool to execute; written under the `slug` key using
    /// whatever representation [`MetaToolSlug`] defines.
    pub slug: MetaToolSlug,

    /// Arbitrary JSON arguments; the `arguments` key is omitted entirely when
    /// this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<serde_json::Value>,
}
/// Request body naming a toolkit and, optionally, a callback URL.
///
/// This type is serialize-only: it derives `Serialize` but not `Deserialize`.
#[derive(Debug, Clone, Serialize)]
pub struct LinkRequest {
    /// Toolkit name; always written.
    pub toolkit: String,

    /// Callback URL; the `callback_url` key is omitted entirely when this is
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub callback_url: Option<String>,
}
#[cfg(test)]
mod tests {
    //! Serialization tests for the request and configuration models.
    //!
    //! Each test serializes a value to a JSON string, parses that string back
    //! into a generic `serde_json::Value`, and asserts on the resulting shape,
    //! so the checks run against the actual wire output.

    use super::*;

    /// Builds the baseline config shared by the `SessionConfig` tests: the
    /// fixed test user id with every optional field unset.
    ///
    /// Individual tests override only the field under test using
    /// struct-update syntax (`..base_config()`).
    fn base_config() -> SessionConfig {
        SessionConfig {
            user_id: "user_123".to_string(),
            toolkits: None,
            auth_configs: None,
            connected_accounts: None,
            manage_connections: None,
            tools: None,
            tags: None,
            workbench: None,
        }
    }

    /// Serializes `value` to a JSON string and parses that string back into a
    /// generic JSON value.
    ///
    /// Going through the string (rather than `serde_json::to_value`) keeps the
    /// assertions tied to exactly what would be sent over the wire.
    fn to_json_value<T: serde::Serialize>(value: &T) -> serde_json::Value {
        let json = serde_json::to_string(value).expect("value should serialize to JSON");
        serde_json::from_str(&json).expect("serialized output should be valid JSON")
    }

    #[test]
    fn test_session_config_minimal_serialization() {
        let parsed = to_json_value(&base_config());

        // Exact comparison: proves that *every* optional field is omitted,
        // not just a sampled few, and that nothing is emitted as `null`.
        assert_eq!(parsed, serde_json::json!({ "user_id": "user_123" }));
    }

    #[test]
    fn test_session_config_with_toolkits_enable() {
        let config = SessionConfig {
            toolkits: Some(ToolkitFilter::Enable(vec![
                "github".to_string(),
                "gmail".to_string(),
            ])),
            ..base_config()
        };

        let parsed = to_json_value(&config);
        assert!(parsed["toolkits"].is_array());
        assert_eq!(parsed["toolkits"], serde_json::json!(["github", "gmail"]));
    }

    #[test]
    fn test_session_config_with_toolkits_disable() {
        let config = SessionConfig {
            toolkits: Some(ToolkitFilter::Disable {
                disable: vec!["exa".to_string(), "firecrawl".to_string()],
            }),
            ..base_config()
        };

        let parsed = to_json_value(&config);
        assert!(parsed["toolkits"].is_object());
        assert_eq!(
            parsed["toolkits"],
            serde_json::json!({ "disable": ["exa", "firecrawl"] })
        );
    }

    #[test]
    fn test_session_config_with_auth_configs() {
        let mut auth_configs = HashMap::new();
        auth_configs.insert("github".to_string(), "ac_custom".to_string());

        let config = SessionConfig {
            auth_configs: Some(auth_configs),
            ..base_config()
        };

        let parsed = to_json_value(&config);
        assert_eq!(parsed["auth_configs"]["github"], "ac_custom");
    }

    #[test]
    fn test_session_config_with_manage_connections_bool() {
        let config = SessionConfig {
            manage_connections: Some(ManageConnectionsConfig::Bool(true)),
            ..base_config()
        };

        let parsed = to_json_value(&config);
        assert_eq!(parsed["manage_connections"], true);
    }

    #[test]
    fn test_session_config_with_manage_connections_detailed() {
        let config = SessionConfig {
            manage_connections: Some(ManageConnectionsConfig::Detailed {
                enabled: true,
                enable_wait_for_connections: Some(false),
            }),
            ..base_config()
        };

        let parsed = to_json_value(&config);
        assert_eq!(parsed["manage_connections"]["enabled"], true);
        assert_eq!(
            parsed["manage_connections"]["enable_wait_for_connections"],
            false
        );
    }

    #[test]
    fn test_session_config_with_tools() {
        let mut tools_map = HashMap::new();
        tools_map.insert(
            "github".to_string(),
            ToolFilter::EnableList(vec!["GITHUB_CREATE_ISSUE".to_string()]),
        );

        let config = SessionConfig {
            tools: Some(ToolsConfig(tools_map)),
            ..base_config()
        };

        let parsed = to_json_value(&config);
        assert!(parsed["tools"]["github"].is_array());
        assert_eq!(
            parsed["tools"]["github"],
            serde_json::json!(["GITHUB_CREATE_ISSUE"])
        );
    }

    #[test]
    fn test_session_config_with_tags() {
        let config = SessionConfig {
            tags: Some(TagsConfig {
                enabled: Some(vec![TagType::ReadOnlyHint]),
                disabled: Some(vec![TagType::DestructiveHint]),
            }),
            ..base_config()
        };

        let parsed = to_json_value(&config);
        assert!(parsed["tags"]["enabled"].is_array());
        assert!(parsed["tags"]["disabled"].is_array());
    }

    #[test]
    fn test_session_config_with_workbench() {
        let config = SessionConfig {
            workbench: Some(WorkbenchConfig {
                proxy_execution: Some(true),
                auto_offload_threshold: Some(1000),
            }),
            ..base_config()
        };

        let parsed = to_json_value(&config);
        assert_eq!(parsed["workbench"]["proxy_execution"], true);
        assert_eq!(parsed["workbench"]["auto_offload_threshold"], 1000);
    }

    #[test]
    fn test_toolkit_filter_enable_serialization() {
        let filter = ToolkitFilter::Enable(vec!["github".to_string(), "gmail".to_string()]);

        let parsed = to_json_value(&filter);
        assert!(parsed.is_array());
        assert_eq!(parsed, serde_json::json!(["github", "gmail"]));
    }

    #[test]
    fn test_toolkit_filter_disable_serialization() {
        let filter = ToolkitFilter::Disable {
            disable: vec!["exa".to_string()],
        };

        let parsed = to_json_value(&filter);
        assert!(parsed.is_object());
        assert_eq!(parsed, serde_json::json!({ "disable": ["exa"] }));
    }

    #[test]
    fn test_tool_filter_enable_serialization() {
        let filter = ToolFilter::Enable {
            enable: vec!["GITHUB_CREATE_ISSUE".to_string()],
        };

        let parsed = to_json_value(&filter);
        assert!(parsed.is_object());
        assert_eq!(
            parsed,
            serde_json::json!({ "enable": ["GITHUB_CREATE_ISSUE"] })
        );
    }

    #[test]
    fn test_tool_filter_disable_serialization() {
        let filter = ToolFilter::Disable {
            disable: vec!["GITHUB_DELETE_REPO".to_string()],
        };

        let parsed = to_json_value(&filter);
        assert!(parsed.is_object());
        assert_eq!(
            parsed,
            serde_json::json!({ "disable": ["GITHUB_DELETE_REPO"] })
        );
    }

    #[test]
    fn test_tool_filter_enable_list_serialization() {
        let filter = ToolFilter::EnableList(vec!["GITHUB_CREATE_ISSUE".to_string()]);

        let parsed = to_json_value(&filter);
        assert!(parsed.is_array());
        assert_eq!(parsed, serde_json::json!(["GITHUB_CREATE_ISSUE"]));
    }

    #[test]
    fn test_tool_execution_request_serialization() {
        let request = ToolExecutionRequest {
            tool_slug: "GITHUB_CREATE_ISSUE".to_string(),
            arguments: Some(serde_json::json!({
                "owner": "composio",
                "repo": "composio",
                "title": "Test issue"
            })),
        };

        let parsed = to_json_value(&request);
        assert_eq!(parsed["tool_slug"], "GITHUB_CREATE_ISSUE");
        assert!(parsed["arguments"].is_object());
        assert_eq!(parsed["arguments"]["owner"], "composio");
    }

    #[test]
    fn test_tool_execution_request_without_arguments() {
        let request = ToolExecutionRequest {
            tool_slug: "GITHUB_GET_USER".to_string(),
            arguments: None,
        };

        let parsed = to_json_value(&request);
        assert_eq!(parsed["tool_slug"], "GITHUB_GET_USER");
        // The key must be absent, not present with a `null` value.
        assert!(parsed.get("arguments").is_none());
    }

    #[test]
    fn test_meta_tool_execution_request_serialization() {
        let request = MetaToolExecutionRequest {
            slug: MetaToolSlug::ComposioSearchTools,
            arguments: Some(serde_json::json!({
                "query": "create a GitHub issue"
            })),
        };

        let parsed = to_json_value(&request);
        assert_eq!(parsed["slug"], "COMPOSIO_SEARCH_TOOLS");
        assert!(parsed["arguments"].is_object());
    }

    #[test]
    fn test_link_request_serialization() {
        let request = LinkRequest {
            toolkit: "github".to_string(),
            callback_url: Some("https://example.com/callback".to_string()),
        };

        let parsed = to_json_value(&request);
        assert_eq!(parsed["toolkit"], "github");
        assert_eq!(parsed["callback_url"], "https://example.com/callback");
    }

    #[test]
    fn test_link_request_without_callback() {
        let request = LinkRequest {
            toolkit: "gmail".to_string(),
            callback_url: None,
        };

        let parsed = to_json_value(&request);
        assert_eq!(parsed["toolkit"], "gmail");
        // The key must be absent, not present with a `null` value.
        assert!(parsed.get("callback_url").is_none());
    }

    #[test]
    fn test_tags_config_serialization() {
        let config = TagsConfig {
            enabled: Some(vec![TagType::ReadOnlyHint, TagType::IdempotentHint]),
            disabled: Some(vec![TagType::DestructiveHint]),
        };

        let parsed = to_json_value(&config);
        assert!(parsed["enabled"].is_array());
        assert!(parsed["disabled"].is_array());
        assert_eq!(parsed["enabled"].as_array().unwrap().len(), 2);
        assert_eq!(parsed["disabled"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn test_workbench_config_serialization() {
        let config = WorkbenchConfig {
            proxy_execution: Some(true),
            auto_offload_threshold: Some(500),
        };

        let parsed = to_json_value(&config);
        assert_eq!(parsed["proxy_execution"], true);
        assert_eq!(parsed["auto_offload_threshold"], 500);
    }

    #[test]
    fn test_workbench_config_partial_serialization() {
        let config = WorkbenchConfig {
            proxy_execution: Some(false),
            auto_offload_threshold: None,
        };

        let parsed = to_json_value(&config);
        assert_eq!(parsed["proxy_execution"], false);
        // The key must be absent, not present with a `null` value.
        assert!(parsed.get("auto_offload_threshold").is_none());
    }
}