use crate::client::ComposioClient;
use crate::error::ComposioError;
use crate::models::request::{
SessionConfig, TagsConfig, ToolFilter, ToolsConfig, ToolkitFilter, WorkbenchConfig,
};
use crate::models::response::ToolSchema;
use crate::models::enums::TagType;
use std::collections::HashMap;
use std::sync::Arc;
/// A tool-router session created on the Composio backend.
///
/// Stores the identifiers returned when the session was created together with a
/// handle to the [`ComposioClient`] used for every follow-up request. The client
/// is held behind an [`Arc`], so clones of a `Session` share the same client.
#[derive(Debug, Clone)]
pub struct Session {
    client: Arc<ComposioClient>,
    session_id: String,
    mcp_url: String,
    tools: Vec<String>,
}

impl Session {
    /// Returns the session identifier assigned by the server.
    pub fn session_id(&self) -> &str {
        &self.session_id
    }

    /// Returns the MCP URL reported by the server for this session.
    pub fn mcp_url(&self) -> &str {
        &self.mcp_url
    }

    /// Returns the tool slugs the server listed (as `tool_router_tools`) when the
    /// session was created.
    pub fn tools(&self) -> &[String] {
        &self.tools
    }

    /// Builds a `Session` from a session-creation response, taking ownership of
    /// `client` and wrapping it in an [`Arc`].
    pub(crate) fn from_response(
        client: ComposioClient,
        response: crate::models::response::SessionResponse,
    ) -> Self {
        Self {
            client: Arc::new(client),
            session_id: response.session_id,
            mcp_url: response.mcp.url,
            tools: response.tool_router_tools,
        }
    }

    /// Returns `{base_url}/tool_router/session/{session_id}/{endpoint}`.
    fn session_url(&self, endpoint: &str) -> String {
        format!(
            "{}/tool_router/session/{}/{}",
            self.client.config().base_url,
            self.session_id,
            endpoint
        )
    }

    /// Executes the tool `tool_slug` with `arguments` inside this session.
    ///
    /// Sends `POST {base_url}/tool_router/session/{session_id}/execute`, wrapped in
    /// the client's configured retry policy.
    ///
    /// # Errors
    ///
    /// Returns `ComposioError::NetworkError` if the request cannot be sent or the
    /// response body cannot be decoded, and the error built by
    /// `ComposioError::from_response` when the server answers with a non-success
    /// status (after whatever retries the policy performs).
    pub async fn execute_tool(
        &self,
        tool_slug: impl Into<String>,
        arguments: serde_json::Value,
    ) -> Result<crate::models::response::ToolExecutionResponse, ComposioError> {
        use crate::models::request::ToolExecutionRequest;
        use crate::models::response::ToolExecutionResponse;
        use crate::retry::with_retry;

        let url = self.session_url("execute");
        // The slug is only needed in the body, so convert it straight into place
        // rather than keeping (and cloning) a separate owned copy.
        let request_body = ToolExecutionRequest {
            tool_slug: tool_slug.into(),
            arguments: Some(arguments),
        };
        let policy = &self.client.config().retry_policy;
        let response = with_retry(policy, || {
            // Each attempt gets owned copies that are moved into its `async move` future.
            let url = url.clone();
            let request_body = request_body.clone();
            let client = self.client.http_client().clone();
            async move {
                let response = client
                    .post(&url)
                    .json(&request_body)
                    .send()
                    .await
                    .map_err(ComposioError::NetworkError)?;
                if !response.status().is_success() {
                    return Err(ComposioError::from_response(response).await);
                }
                Ok(response)
            }
        })
        .await?;
        response
            .json::<ToolExecutionResponse>()
            .await
            .map_err(ComposioError::NetworkError)
    }

    /// Executes the meta tool identified by `slug` with `arguments`.
    ///
    /// Sends `POST {base_url}/tool_router/session/{session_id}/execute_meta`,
    /// wrapped in the client's configured retry policy.
    ///
    /// # Errors
    ///
    /// Returns `ComposioError::NetworkError` if the request cannot be sent or the
    /// response body cannot be decoded, and the error built by
    /// `ComposioError::from_response` for a non-success status.
    pub async fn execute_meta_tool(
        &self,
        slug: crate::models::enums::MetaToolSlug,
        arguments: serde_json::Value,
    ) -> Result<crate::models::response::MetaToolExecutionResponse, ComposioError> {
        use crate::models::request::MetaToolExecutionRequest;
        use crate::models::response::MetaToolExecutionResponse;
        use crate::retry::with_retry;

        let url = self.session_url("execute_meta");
        let request_body = MetaToolExecutionRequest {
            slug,
            arguments: Some(arguments),
        };
        let policy = &self.client.config().retry_policy;
        let response = with_retry(policy, || {
            let url = url.clone();
            let request_body = request_body.clone();
            let client = self.client.http_client().clone();
            async move {
                let response = client
                    .post(&url)
                    .json(&request_body)
                    .send()
                    .await
                    .map_err(ComposioError::NetworkError)?;
                if !response.status().is_success() {
                    return Err(ComposioError::from_response(response).await);
                }
                Ok(response)
            }
        })
        .await?;
        response
            .json::<MetaToolExecutionResponse>()
            .await
            .map_err(ComposioError::NetworkError)
    }

    /// Starts a [`ToolkitListBuilder`] for this session.
    ///
    /// No request is sent until the builder's `send` is awaited.
    pub fn list_toolkits(&self) -> ToolkitListBuilder<'_> {
        ToolkitListBuilder::new(self)
    }

    /// Fetches the tool schemas exposed at
    /// `GET {base_url}/tool_router/session/{session_id}/tools`.
    ///
    /// # Errors
    ///
    /// Returns `ComposioError::NetworkError` if the request cannot be sent or the
    /// body cannot be decoded as a list of [`ToolSchema`], and the error built by
    /// `ComposioError::from_response` for a non-success status.
    pub async fn get_meta_tools(&self) -> Result<Vec<ToolSchema>, ComposioError> {
        use crate::retry::with_retry;

        let url = self.session_url("tools");
        let policy = &self.client.config().retry_policy;
        let response = with_retry(policy, || {
            let url = url.clone();
            let client = self.client.http_client().clone();
            async move {
                let response = client
                    .get(&url)
                    .send()
                    .await
                    .map_err(ComposioError::NetworkError)?;
                if !response.status().is_success() {
                    return Err(ComposioError::from_response(response).await);
                }
                Ok(response)
            }
        })
        .await?;
        response
            .json::<Vec<ToolSchema>>()
            .await
            .map_err(ComposioError::NetworkError)
    }

    /// Requests an authentication link for `toolkit`.
    ///
    /// Sends `POST {base_url}/tool_router/session/{session_id}/link` with the
    /// toolkit and the optional `callback_url`.
    ///
    /// # Errors
    ///
    /// Returns `ComposioError::NetworkError` if the request cannot be sent or the
    /// response body cannot be decoded, and the error built by
    /// `ComposioError::from_response` for a non-success status.
    pub async fn create_auth_link(
        &self,
        toolkit: impl Into<String>,
        callback_url: Option<String>,
    ) -> Result<crate::models::response::LinkResponse, ComposioError> {
        use crate::models::request::LinkRequest;
        use crate::retry::with_retry;

        let url = self.session_url("link");
        // `toolkit` is only needed in the body; convert it in place (no extra clone).
        let request_body = LinkRequest {
            toolkit: toolkit.into(),
            callback_url,
        };
        let policy = &self.client.config().retry_policy;
        let response = with_retry(policy, || {
            let url = url.clone();
            let client = self.client.http_client().clone();
            let body = request_body.clone();
            async move {
                let response = client
                    .post(&url)
                    .json(&body)
                    .send()
                    .await
                    .map_err(ComposioError::NetworkError)?;
                if !response.status().is_success() {
                    return Err(ComposioError::from_response(response).await);
                }
                Ok(response)
            }
        })
        .await?;
        response
            .json::<crate::models::response::LinkResponse>()
            .await
            .map_err(ComposioError::NetworkError)
    }
}
pub struct SessionBuilder<'a> {
client: &'a ComposioClient,
#[allow(dead_code)]
user_id: String,
config: SessionConfig,
}
impl<'a> SessionBuilder<'a> {
pub fn new(client: &'a ComposioClient, user_id: String) -> Self {
Self {
client,
user_id: user_id.clone(),
config: SessionConfig {
user_id,
toolkits: None,
auth_configs: None,
connected_accounts: None,
manage_connections: None,
tools: None,
tags: None,
workbench: None,
},
}
}
pub fn toolkits(mut self, toolkits: Vec<impl Into<String>>) -> Self {
self.config.toolkits = Some(ToolkitFilter::Enable(
toolkits.into_iter().map(|t| t.into()).collect(),
));
self
}
pub fn disable_toolkits(mut self, toolkits: Vec<impl Into<String>>) -> Self {
self.config.toolkits = Some(ToolkitFilter::Disable {
disable: toolkits.into_iter().map(|t| t.into()).collect(),
});
self
}
pub fn auth_config(
mut self,
toolkit: impl Into<String>,
auth_config_id: impl Into<String>,
) -> Self {
self.config
.auth_configs
.get_or_insert_with(HashMap::new)
.insert(toolkit.into(), auth_config_id.into());
self
}
pub fn connected_account(
mut self,
toolkit: impl Into<String>,
connected_account_id: impl Into<String>,
) -> Self {
self.config
.connected_accounts
.get_or_insert_with(HashMap::new)
.insert(toolkit.into(), connected_account_id.into());
self
}
pub fn manage_connections(mut self, enabled: bool) -> Self {
self.config.manage_connections = Some(crate::models::ManageConnectionsConfig::Bool(enabled));
self
}
pub fn tools(mut self, toolkit: impl Into<String>, tools: Vec<impl Into<String>>) -> Self {
let tool_filter = ToolFilter::EnableList(tools.into_iter().map(|t| t.into()).collect());
self.config
.tools
.get_or_insert_with(|| ToolsConfig(HashMap::new()))
.0
.insert(toolkit.into(), tool_filter);
self
}
pub fn tags(
mut self,
enabled: Option<Vec<TagType>>,
disabled: Option<Vec<TagType>>,
) -> Self {
self.config.tags = Some(TagsConfig { enabled, disabled });
self
}
pub fn workbench(
mut self,
proxy_execution: Option<bool>,
auto_offload_threshold: Option<u32>,
) -> Self {
self.config.workbench = Some(WorkbenchConfig {
proxy_execution,
auto_offload_threshold,
});
self
}
pub async fn send(self) -> Result<Session, ComposioError> {
use crate::models::response::SessionResponse;
use crate::retry::with_retry;
let url = format!("{}/tool_router/session", self.client.config().base_url);
let policy = &self.client.config().retry_policy;
let response = with_retry(policy, || {
let url = url.clone();
let config = self.config.clone();
let client = self.client.http_client().clone();
async move {
let response = client
.post(&url)
.json(&config)
.send()
.await
.map_err(ComposioError::NetworkError)?;
if !response.status().is_success() {
return Err(ComposioError::from_response(response).await);
}
Ok(response)
}
})
.await?;
let session_response: SessionResponse = response
.json()
.await
.map_err(ComposioError::NetworkError)?;
Ok(Session {
client: Arc::new(self.client.clone()),
session_id: session_response.session_id,
mcp_url: session_response.mcp.url,
tools: session_response.tool_router_tools,
})
}
}
/// Builder for listing the toolkits available in a [`Session`].
///
/// Obtained from [`Session::list_toolkits`]. Every filter is optional, and only
/// the filters that were set are sent as query parameters.
#[derive(Debug)]
#[must_use = "a ToolkitListBuilder does nothing until `send` is awaited"]
pub struct ToolkitListBuilder<'a> {
    session: &'a Session,
    limit: Option<u32>,
    cursor: Option<String>,
    toolkits: Option<Vec<String>>,
    is_connected: Option<bool>,
    search: Option<String>,
}

impl<'a> ToolkitListBuilder<'a> {
    /// Creates a builder for `session` with no filters set.
    fn new(session: &'a Session) -> Self {
        Self {
            session,
            limit: None,
            cursor: None,
            toolkits: None,
            is_connected: None,
            search: None,
        }
    }

    /// Sets the `limit` query parameter.
    pub fn limit(mut self, limit: u32) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Sets the `cursor` query parameter.
    pub fn cursor(mut self, cursor: impl Into<String>) -> Self {
        self.cursor = Some(cursor.into());
        self
    }

    /// Sets the `toolkits` filter. It is sent as one comma-joined value.
    pub fn toolkits(mut self, toolkits: Vec<impl Into<String>>) -> Self {
        self.toolkits = Some(toolkits.into_iter().map(|t| t.into()).collect());
        self
    }

    /// Sets the `is_connected` query parameter.
    pub fn is_connected(mut self, is_connected: bool) -> Self {
        self.is_connected = Some(is_connected);
        self
    }

    /// Sets the `search` query parameter.
    pub fn search(mut self, search: impl Into<String>) -> Self {
        self.search = Some(search.into());
        self
    }

    /// Fetches the toolkit list via
    /// `GET {base_url}/tool_router/session/{session_id}/toolkits`.
    ///
    /// The request is wrapped in the client's configured retry policy.
    ///
    /// # Errors
    ///
    /// Returns `ComposioError::NetworkError` if the request cannot be sent or the
    /// response body cannot be decoded, and the error built by
    /// `ComposioError::from_response` for a non-success status.
    pub async fn send(self) -> Result<crate::models::response::ToolkitListResponse, ComposioError> {
        use crate::models::response::ToolkitListResponse;
        use crate::retry::with_retry;

        // `self` is owned, so take it apart and move the strings into the query
        // list instead of cloning them.
        let Self {
            session,
            limit,
            cursor,
            toolkits,
            is_connected,
            search,
        } = self;

        let url = format!(
            "{}/tool_router/session/{}/toolkits",
            session.client.config().base_url,
            session.session_id
        );

        // At most one entry per optional filter, in a fixed order.
        let mut query_params: Vec<(&str, String)> = Vec::with_capacity(5);
        if let Some(limit) = limit {
            query_params.push(("limit", limit.to_string()));
        }
        if let Some(cursor) = cursor {
            query_params.push(("cursor", cursor));
        }
        if let Some(toolkits) = toolkits {
            query_params.push(("toolkits", toolkits.join(",")));
        }
        if let Some(is_connected) = is_connected {
            query_params.push(("is_connected", is_connected.to_string()));
        }
        if let Some(search) = search {
            query_params.push(("search", search));
        }

        let policy = &session.client.config().retry_policy;
        let response = with_retry(policy, || {
            // Owned copies for this attempt's `async move` future.
            let url = url.clone();
            let query_params = query_params.clone();
            let client = session.client.http_client().clone();
            async move {
                let response = client
                    .get(&url)
                    .query(&query_params)
                    .send()
                    .await
                    .map_err(ComposioError::NetworkError)?;
                if !response.status().is_success() {
                    return Err(ComposioError::from_response(response).await);
                }
                Ok(response)
            }
        })
        .await?;
        response
            .json::<ToolkitListResponse>()
            .await
            .map_err(ComposioError::NetworkError)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ComposioClient;
    use crate::models::enums::TagType;
    use crate::models::request::{ManageConnectionsConfig, ToolFilter, ToolkitFilter};

    /// Builds a client configured with a dummy API key.
    fn create_test_client() -> ComposioClient {
        ComposioClient::builder()
            .api_key("test_api_key")
            .build()
            .unwrap()
    }

    /// Starts a `SessionBuilder` for the fixed test user `user_123`.
    fn builder_for_test_user(client: &ComposioClient) -> SessionBuilder<'_> {
        SessionBuilder::new(client, "user_123".to_string())
    }

    /// Constructs a `Session` directly from its fields (no request is made),
    /// with fixed id/URL and the given tool slugs.
    fn test_session_with_tools(tools: Vec<String>) -> Session {
        Session {
            client: Arc::new(create_test_client()),
            session_id: "sess_123".to_string(),
            mcp_url: "https://mcp.composio.dev".to_string(),
            tools,
        }
    }

    #[test]
    fn test_session_builder_new() {
        let client = create_test_client();
        let builder = builder_for_test_user(&client);
        assert_eq!(builder.user_id, "user_123");
        let config = &builder.config;
        assert!(config.toolkits.is_none());
        assert!(config.auth_configs.is_none());
        assert!(config.connected_accounts.is_none());
        assert!(config.manage_connections.is_none());
        assert!(config.tools.is_none());
        assert!(config.tags.is_none());
        assert!(config.workbench.is_none());
    }

    #[test]
    fn test_session_builder_toolkits_enable() {
        let client = create_test_client();
        let builder = builder_for_test_user(&client).toolkits(vec!["github", "gmail"]);
        let enabled = match builder.config.toolkits {
            Some(ToolkitFilter::Enable(toolkits)) => toolkits,
            _ => panic!("Expected Enable variant"),
        };
        assert_eq!(enabled.len(), 2);
        assert!(enabled.contains(&"github".to_string()));
        assert!(enabled.contains(&"gmail".to_string()));
    }

    #[test]
    fn test_session_builder_disable_toolkits() {
        let client = create_test_client();
        let builder = builder_for_test_user(&client).disable_toolkits(vec!["exa", "firecrawl"]);
        let disabled = match builder.config.toolkits {
            Some(ToolkitFilter::Disable { disable }) => disable,
            _ => panic!("Expected Disable variant"),
        };
        assert_eq!(disabled.len(), 2);
        assert!(disabled.contains(&"exa".to_string()));
        assert!(disabled.contains(&"firecrawl".to_string()));
    }

    #[test]
    fn test_session_builder_auth_config() {
        let client = create_test_client();
        let configs = builder_for_test_user(&client)
            .auth_config("github", "ac_custom_config")
            .config
            .auth_configs
            .unwrap();
        assert_eq!(configs.len(), 1);
        assert_eq!(configs.get("github"), Some(&"ac_custom_config".to_string()));
    }

    #[test]
    fn test_session_builder_multiple_auth_configs() {
        let client = create_test_client();
        let configs = builder_for_test_user(&client)
            .auth_config("github", "ac_github_config")
            .auth_config("gmail", "ac_gmail_config")
            .config
            .auth_configs
            .unwrap();
        assert_eq!(configs.len(), 2);
        assert_eq!(configs.get("github"), Some(&"ac_github_config".to_string()));
        assert_eq!(configs.get("gmail"), Some(&"ac_gmail_config".to_string()));
    }

    #[test]
    fn test_session_builder_connected_account() {
        let client = create_test_client();
        let accounts = builder_for_test_user(&client)
            .connected_account("gmail", "ca_work_gmail")
            .config
            .connected_accounts
            .unwrap();
        assert_eq!(accounts.len(), 1);
        assert_eq!(accounts.get("gmail"), Some(&"ca_work_gmail".to_string()));
    }

    #[test]
    fn test_session_builder_multiple_connected_accounts() {
        let client = create_test_client();
        let accounts = builder_for_test_user(&client)
            .connected_account("gmail", "ca_work_gmail")
            .connected_account("github", "ca_personal_github")
            .config
            .connected_accounts
            .unwrap();
        assert_eq!(accounts.len(), 2);
        assert_eq!(accounts.get("gmail"), Some(&"ca_work_gmail".to_string()));
        assert_eq!(accounts.get("github"), Some(&"ca_personal_github".to_string()));
    }

    #[test]
    fn test_session_builder_manage_connections_true() {
        let client = create_test_client();
        let builder = builder_for_test_user(&client).manage_connections(true);
        match builder.config.manage_connections {
            Some(ManageConnectionsConfig::Bool(flag)) => assert!(flag),
            _ => panic!("Expected Bool variant with true"),
        }
    }

    #[test]
    fn test_session_builder_manage_connections_false() {
        let client = create_test_client();
        let builder = builder_for_test_user(&client).manage_connections(false);
        match builder.config.manage_connections {
            Some(ManageConnectionsConfig::Bool(flag)) => assert!(!flag),
            _ => panic!("Expected Bool variant with false"),
        }
    }

    #[test]
    fn test_session_builder_tools_enable() {
        let client = create_test_client();
        let tools_config = builder_for_test_user(&client)
            .tools("github", vec!["GITHUB_CREATE_ISSUE", "GITHUB_GET_REPOS"])
            .config
            .tools
            .unwrap();
        match tools_config.0.get("github").unwrap() {
            ToolFilter::EnableList(enabled) => {
                assert_eq!(enabled.len(), 2);
                assert!(enabled.contains(&"GITHUB_CREATE_ISSUE".to_string()));
                assert!(enabled.contains(&"GITHUB_GET_REPOS".to_string()));
            }
            _ => panic!("Expected EnableList variant"),
        }
    }

    #[test]
    fn test_session_builder_multiple_toolkit_tools() {
        let client = create_test_client();
        let tools_config = builder_for_test_user(&client)
            .tools("github", vec!["GITHUB_CREATE_ISSUE"])
            .tools("gmail", vec!["GMAIL_SEND_EMAIL"])
            .config
            .tools
            .unwrap();
        let per_toolkit = &tools_config.0;
        assert_eq!(per_toolkit.len(), 2);
        assert!(per_toolkit.contains_key("github"));
        assert!(per_toolkit.contains_key("gmail"));
    }

    #[test]
    fn test_session_builder_tags_enabled() {
        let client = create_test_client();
        let tags_config = builder_for_test_user(&client)
            .tags(Some(vec![TagType::ReadOnlyHint, TagType::IdempotentHint]), None)
            .config
            .tags
            .unwrap();
        let enabled = tags_config.enabled.unwrap();
        assert_eq!(enabled.len(), 2);
        assert!(enabled.contains(&TagType::ReadOnlyHint));
        assert!(enabled.contains(&TagType::IdempotentHint));
        assert!(tags_config.disabled.is_none());
    }

    #[test]
    fn test_session_builder_tags_disabled() {
        let client = create_test_client();
        let tags_config = builder_for_test_user(&client)
            .tags(None, Some(vec![TagType::DestructiveHint]))
            .config
            .tags
            .unwrap();
        let disabled = tags_config.disabled.unwrap();
        assert_eq!(disabled.len(), 1);
        assert!(disabled.contains(&TagType::DestructiveHint));
        assert!(tags_config.enabled.is_none());
    }

    #[test]
    fn test_session_builder_tags_both() {
        let client = create_test_client();
        let tags_config = builder_for_test_user(&client)
            .tags(
                Some(vec![TagType::ReadOnlyHint]),
                Some(vec![TagType::DestructiveHint]),
            )
            .config
            .tags
            .unwrap();
        assert!(tags_config.enabled.is_some());
        assert!(tags_config.disabled.is_some());
    }

    #[test]
    fn test_session_builder_workbench() {
        let client = create_test_client();
        let workbench = builder_for_test_user(&client)
            .workbench(Some(true), Some(1000))
            .config
            .workbench
            .unwrap();
        assert_eq!(workbench.proxy_execution, Some(true));
        assert_eq!(workbench.auto_offload_threshold, Some(1000));
    }

    #[test]
    fn test_session_builder_workbench_no_threshold() {
        let client = create_test_client();
        let workbench = builder_for_test_user(&client)
            .workbench(Some(false), None)
            .config
            .workbench
            .unwrap();
        assert_eq!(workbench.proxy_execution, Some(false));
        assert_eq!(workbench.auto_offload_threshold, None);
    }

    #[test]
    fn test_session_builder_method_chaining() {
        let client = create_test_client();
        let config = builder_for_test_user(&client)
            .toolkits(vec!["github", "gmail"])
            .auth_config("github", "ac_custom")
            .connected_account("gmail", "ca_work")
            .manage_connections(true)
            .tools("github", vec!["GITHUB_CREATE_ISSUE"])
            .tags(Some(vec![TagType::ReadOnlyHint]), None)
            .workbench(Some(true), Some(500))
            .config;
        assert!(config.toolkits.is_some());
        assert!(config.auth_configs.is_some());
        assert!(config.connected_accounts.is_some());
        assert!(config.manage_connections.is_some());
        assert!(config.tools.is_some());
        assert!(config.tags.is_some());
        assert!(config.workbench.is_some());
    }

    #[test]
    fn test_session_session_id_accessor() {
        let session = test_session_with_tools(vec!["COMPOSIO_SEARCH_TOOLS".to_string()]);
        assert_eq!(session.session_id(), "sess_123");
    }

    #[test]
    fn test_session_mcp_url_accessor() {
        let session = test_session_with_tools(vec!["COMPOSIO_SEARCH_TOOLS".to_string()]);
        assert_eq!(session.mcp_url(), "https://mcp.composio.dev");
    }

    #[test]
    fn test_session_tools_accessor() {
        let expected = vec![
            "COMPOSIO_SEARCH_TOOLS".to_string(),
            "COMPOSIO_MULTI_EXECUTE_TOOL".to_string(),
        ];
        let session = test_session_with_tools(expected.clone());
        assert_eq!(session.tools(), &expected);
        assert_eq!(session.tools().len(), 2);
    }

    #[test]
    fn test_toolkit_list_builder_new() {
        let session = test_session_with_tools(Vec::new());
        let builder = ToolkitListBuilder::new(&session);
        assert!(builder.limit.is_none());
        assert!(builder.cursor.is_none());
        assert!(builder.toolkits.is_none());
        assert!(builder.is_connected.is_none());
        assert!(builder.search.is_none());
    }

    #[test]
    fn test_toolkit_list_builder_limit() {
        let session = test_session_with_tools(Vec::new());
        assert_eq!(session.list_toolkits().limit(50).limit, Some(50));
    }

    #[test]
    fn test_toolkit_list_builder_cursor() {
        let session = test_session_with_tools(Vec::new());
        let builder = session.list_toolkits().cursor("cursor_abc");
        assert_eq!(builder.cursor, Some("cursor_abc".to_string()));
    }

    #[test]
    fn test_toolkit_list_builder_toolkits() {
        let session = test_session_with_tools(Vec::new());
        let selected = session
            .list_toolkits()
            .toolkits(vec!["github", "gmail"])
            .toolkits
            .unwrap();
        assert_eq!(selected.len(), 2);
        assert!(selected.contains(&"github".to_string()));
        assert!(selected.contains(&"gmail".to_string()));
    }

    #[test]
    fn test_toolkit_list_builder_is_connected() {
        let session = test_session_with_tools(Vec::new());
        assert_eq!(session.list_toolkits().is_connected(true).is_connected, Some(true));
    }

    #[test]
    fn test_toolkit_list_builder_search() {
        let session = test_session_with_tools(Vec::new());
        let builder = session.list_toolkits().search("communication");
        assert_eq!(builder.search, Some("communication".to_string()));
    }

    #[test]
    fn test_toolkit_list_builder_method_chaining() {
        let session = test_session_with_tools(Vec::new());
        let builder = session
            .list_toolkits()
            .limit(25)
            .cursor("cursor_xyz")
            .toolkits(vec!["github"])
            .is_connected(true)
            .search("git");
        assert_eq!(builder.limit, Some(25));
        assert_eq!(builder.cursor, Some("cursor_xyz".to_string()));
        assert!(builder.toolkits.is_some());
        assert_eq!(builder.is_connected, Some(true));
        assert_eq!(builder.search, Some("git".to_string()));
    }
}