use crate::client::ComposioClient;
use crate::error::ComposioError;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Status of a connected account.
///
/// Values are (de)serialized as `SCREAMING_SNAKE_CASE` strings, e.g.
/// `ConnectionStatus::Active` <-> `"ACTIVE"`. A status string that does not
/// match one of the variants below fails to deserialize.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the status can be
/// used as a map/set key and in contexts that require total equality.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ConnectionStatus {
    /// Wire value `ACTIVE`. This is the only status that
    /// `ConnectionManager::is_connected` treats as connected.
    Active,
    /// Wire value `INITIATED`.
    Initiated,
    /// Wire value `EXPIRED`.
    Expired,
    /// Wire value `FAILED`.
    Failed,
    /// Wire value `INACTIVE`.
    Inactive,
}
/// One connected-account record.
///
/// `ConnectionManager::list_connections` builds these from the
/// `connected_account` object found on each item of the response. All fields
/// are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectedAccount {
    /// Account identifier (the tests in this file use values like `ca_123`).
    pub id: String,
    /// Toolkit name; compared by exact string equality in
    /// `ConnectionManager::is_connected` and
    /// `ConnectionManager::get_connection_status`.
    pub toolkit: String,
    /// Current status of this account.
    pub status: ConnectionStatus,
    /// User identifier attached to the account.
    // NOTE(review): presumably the owning user's id -- confirm against the API.
    pub user_id: String,
    /// Creation timestamp, kept as the raw string from the payload (not parsed
    /// into a date type here).
    pub created_at: String,
    /// Last-update timestamp, kept as the raw string from the payload.
    pub updated_at: String,
}
/// Authentication link data.
///
/// `ConnectionManager::create_auth_link` deserializes this from the `data`
/// field of the response body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthLink {
    /// URL string returned for the link.
    // NOTE(review): presumably where the end user is sent to authorize -- confirm.
    pub redirect_url: String,
    /// Token string associated with the link.
    pub link_token: String,
    /// Expiry timestamp, kept as the raw string from the payload.
    pub expires_at: String,
    /// Optional connected-account id; left out of the serialized output
    /// entirely when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub connected_account_id: Option<String>,
}
/// Lists connected accounts and creates authentication links for a session,
/// issuing its HTTP requests through a shared [`ComposioClient`].
pub struct ConnectionManager {
    /// Shared client supplying the base URL (`config().base_url`) and the HTTP
    /// client (`http_client()`) used for every request.
    client: Arc<ComposioClient>,
}
impl ConnectionManager {
pub fn new(client: Arc<ComposioClient>) -> Self {
Self { client }
}
pub async fn list_connections(
&self,
session_id: &str,
) -> Result<Vec<ConnectedAccount>, ComposioError> {
let url = format!(
"{}/tool_router/session/{}/toolkits",
self.client.config().base_url,
session_id
);
let response = self
.client
.http_client()
.get(&url)
.send()
.await?;
if !response.status().is_success() {
return Err(ComposioError::from_response(response).await);
}
let data: serde_json::Value = response.json().await?;
let accounts = data["data"]["items"]
.as_array()
.ok_or_else(|| {
ComposioError::InvalidInput("Invalid response format".to_string())
})?
.iter()
.filter_map(|item| {
item["connected_account"]
.as_object()
.and_then(|acc| serde_json::from_value(serde_json::Value::Object(acc.clone())).ok())
})
.collect();
Ok(accounts)
}
pub async fn create_auth_link(
&self,
session_id: &str,
toolkit: &str,
callback_url: Option<&str>,
) -> Result<AuthLink, ComposioError> {
let url = format!(
"{}/tool_router/session/{}/link",
self.client.config().base_url,
session_id
);
let mut body = serde_json::json!({
"toolkit": toolkit,
});
if let Some(callback) = callback_url {
body["callback_url"] = serde_json::json!(callback);
}
let response = self
.client
.http_client()
.post(&url)
.json(&body)
.send()
.await?;
if !response.status().is_success() {
return Err(ComposioError::from_response(response).await);
}
let data: serde_json::Value = response.json().await?;
let link: AuthLink = serde_json::from_value(data["data"].clone())?;
Ok(link)
}
pub async fn is_connected(&self, session_id: &str, toolkit: &str) -> Result<bool, ComposioError> {
let accounts = self.list_connections(session_id).await?;
Ok(accounts
.iter()
.any(|acc| acc.toolkit == toolkit && acc.status == ConnectionStatus::Active))
}
pub async fn get_connection_status(
&self,
session_id: &str,
toolkit: &str,
) -> Result<Option<ConnectionStatus>, ComposioError> {
let accounts = self.list_connections(session_id).await?;
Ok(accounts
.iter()
.find(|acc| acc.toolkit == toolkit)
.map(|acc| acc.status.clone()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `AuthLink` value shared by the auth-link tests; only the
    /// optional account id varies between them.
    fn auth_link_fixture(connected_account_id: Option<&str>) -> AuthLink {
        AuthLink {
            redirect_url: "https://auth.composio.dev/...".to_string(),
            link_token: "lt_abc123".to_string(),
            expires_at: "2024-01-01T01:00:00Z".to_string(),
            connected_account_id: connected_account_id.map(String::from),
        }
    }

    /// `Active` serializes to `"ACTIVE"` and round-trips back.
    #[test]
    fn test_connection_status_serialization() {
        let encoded = serde_json::to_string(&ConnectionStatus::Active).unwrap();
        assert_eq!(encoded, "\"ACTIVE\"");

        let decoded: ConnectionStatus = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, ConnectionStatus::Active);
    }

    /// A `ConnectedAccount` serializes its id, toolkit and status, and
    /// round-trips through JSON.
    #[test]
    fn test_connected_account_serialization() {
        let original = ConnectedAccount {
            id: "ca_123".to_string(),
            toolkit: "github".to_string(),
            status: ConnectionStatus::Active,
            user_id: "user_123".to_string(),
            created_at: "2024-01-01T00:00:00Z".to_string(),
            updated_at: "2024-01-01T00:00:00Z".to_string(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        for fragment in ["ca_123", "github", "ACTIVE"].iter().copied() {
            assert!(encoded.contains(fragment));
        }

        let decoded: ConnectedAccount = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, "ca_123");
        assert_eq!(decoded.status, ConnectionStatus::Active);
    }

    /// An `AuthLink` with an account id serializes the field name, the token
    /// and the account id.
    #[test]
    fn test_auth_link_serialization() {
        let encoded = serde_json::to_string(&auth_link_fixture(Some("ca_123"))).unwrap();
        for fragment in ["redirect_url", "lt_abc123", "ca_123"].iter().copied() {
            assert!(encoded.contains(fragment));
        }
    }

    /// A `None` account id is omitted from the serialized output.
    #[test]
    fn test_auth_link_without_account_id() {
        let encoded = serde_json::to_string(&auth_link_fixture(None)).unwrap();
        assert!(!encoded.contains("connected_account_id"));
    }
}