use std::collections::HashMap;
use async_trait::async_trait;
use tokio::sync::mpsc;
use crate::channels::relay::client::{ChannelEvent, RelayClient};
use crate::channels::{Channel, IncomingMessage, MessageStream, OutgoingResponse, StatusUpdate};
use crate::error::ChannelError;
/// Channel name registered for the Slack relay (currently the only provider).
pub const DEFAULT_RELAY_NAME: &str = "slack-relay";

/// Upstream chat platform that a relay channel talks to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelayProvider {
    /// Slack, reached through the relay's provider proxy.
    Slack,
}

impl RelayProvider {
    /// Short provider identifier, e.g. `"slack"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            RelayProvider::Slack => "slack",
        }
    }

    /// Channel name used for a channel backed by this provider.
    pub fn channel_name(&self) -> &'static str {
        match *self {
            RelayProvider::Slack => DEFAULT_RELAY_NAME,
        }
    }
}
/// [`Channel`] implementation that exchanges messages with a chat provider
/// through a [`RelayClient`].
pub struct RelayChannel {
    /// Upstream provider the relay proxies to.
    provider: RelayProvider,
    /// Client used for every call to the relay service.
    client: RelayClient,
    /// Default team id: used by `broadcast`, and as the fallback when message
    /// metadata carries no `team_id`.
    team_id: String,
    /// Instance id passed to `list_connections` during health checks.
    instance_id: String,
    /// Sender half retained so extra producers can be handed out via `event_sender`.
    event_tx: mpsc::Sender<ChannelEvent>,
    /// Receiver half; moved out (leaving `None`) by the first `start` call.
    /// An async mutex is used because the lock is acquired with `.await`.
    event_rx: tokio::sync::Mutex<Option<mpsc::Receiver<ChannelEvent>>>,
}
impl RelayChannel {
    /// Creates a relay channel for Slack.
    ///
    /// Equivalent to [`RelayChannel::new_with_provider`] with [`RelayProvider::Slack`].
    pub fn new(
        client: RelayClient,
        team_id: String,
        instance_id: String,
        event_tx: mpsc::Sender<ChannelEvent>,
        event_rx: mpsc::Receiver<ChannelEvent>,
    ) -> Self {
        let provider = RelayProvider::Slack;
        Self::new_with_provider(client, provider, team_id, instance_id, event_tx, event_rx)
    }

    /// Creates a relay channel for an explicit provider.
    ///
    /// `event_rx` is consumed by the first `start` call. `event_tx` is kept so
    /// that `event_sender` can hand out clones. Presumably both are halves of
    /// one channel (the tests build them that way); this is not checked here.
    pub fn new_with_provider(
        client: RelayClient,
        provider: RelayProvider,
        team_id: String,
        instance_id: String,
        event_tx: mpsc::Sender<ChannelEvent>,
        event_rx: mpsc::Receiver<ChannelEvent>,
    ) -> Self {
        let event_rx = tokio::sync::Mutex::new(Some(event_rx));
        Self {
            provider,
            client,
            team_id,
            instance_id,
            event_tx,
            event_rx,
        }
    }

    /// Returns a fresh handle for pushing relay events into this channel.
    pub fn event_sender(&self) -> mpsc::Sender<ChannelEvent> {
        mpsc::Sender::clone(&self.event_tx)
    }

    /// Builds the provider API method name and JSON payload for a plain text post.
    ///
    /// For Slack this is `chat.postMessage` with `channel` and `text`, plus
    /// `thread_ts` when a thread id is given.
    fn build_send_body(
        &self,
        channel_id: &str,
        text: &str,
        thread_id: Option<&str>,
    ) -> (String, serde_json::Value) {
        match self.provider {
            RelayProvider::Slack => {
                let mut payload = serde_json::Map::new();
                payload.insert("channel".to_owned(), serde_json::Value::from(channel_id));
                payload.insert("text".to_owned(), serde_json::Value::from(text));
                if let Some(ts) = thread_id {
                    payload.insert("thread_ts".to_owned(), serde_json::Value::from(ts));
                }
                (String::from("chat.postMessage"), serde_json::Value::Object(payload))
            }
        }
    }

    /// Forwards a provider API call through the relay, tagged with this
    /// channel's provider identifier.
    async fn proxy_send(
        &self,
        team_id: &str,
        method: &str,
        body: serde_json::Value,
    ) -> Result<serde_json::Value, crate::channels::relay::client::RelayError> {
        let provider = self.provider.as_str();
        self.client
            .proxy_provider(provider, team_id, method, body)
            .await
    }
}
#[async_trait]
impl Channel for RelayChannel {
fn name(&self) -> &str {
self.provider.channel_name()
}
async fn start(&self) -> Result<MessageStream, ChannelError> {
let channel_name = self.name().to_string();
let mut event_rx =
self.event_rx
.lock()
.await
.take()
.ok_or_else(|| ChannelError::StartupFailed {
name: channel_name.clone(),
reason: "RelayChannel already started".to_string(),
})?;
let (tx, rx) = mpsc::channel(64);
let provider_str = self.provider.as_str().to_string();
let relay_name = channel_name.clone();
tokio::spawn(async move {
while let Some(event) = event_rx.recv().await {
if event.sender_id.is_empty()
|| event.channel_id.is_empty()
|| event.provider_scope.is_empty()
{
tracing::debug!(
event_type = %event.event_type,
sender_id = %event.sender_id,
channel_id = %event.channel_id,
"Relay: skipping event with missing required fields"
);
continue;
}
if !event.is_message() {
tracing::debug!(
event_type = %event.event_type,
"Relay: skipping non-message event"
);
continue;
}
tracing::info!(
event_type = %event.event_type,
sender = %event.sender_id,
channel = %event.channel_id,
provider = %provider_str,
"Relay: received message from {}", provider_str
);
let msg = IncomingMessage::new(&relay_name, &event.sender_id, event.text())
.with_user_name(event.display_name())
.with_metadata(serde_json::json!({
"team_id": event.team_id(),
"channel_id": event.channel_id,
"sender_id": event.sender_id,
"sender_name": event.display_name(),
"event_type": event.event_type,
"thread_id": event.thread_id,
"provider": event.provider,
}));
let msg = if let Some(ref thread_id) = event.thread_id {
msg.with_thread(thread_id)
} else {
msg.with_thread(&event.channel_id)
};
if tx.send(msg).await.is_err() {
tracing::info!("Relay channel receiver dropped, stopping");
return;
}
}
tracing::info!("Relay event channel closed");
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
Ok(Box::pin(stream))
}
async fn respond(
&self,
msg: &IncomingMessage,
response: OutgoingResponse,
) -> Result<(), ChannelError> {
let channel_name = self.name().to_string();
let metadata = &msg.metadata;
let team_id = metadata
.get("team_id")
.and_then(|v| v.as_str())
.unwrap_or(&self.team_id);
let channel_id = metadata
.get("channel_id")
.and_then(|v| v.as_str())
.ok_or_else(|| ChannelError::SendFailed {
name: channel_name.clone(),
reason: "Missing channel_id in message metadata".to_string(),
})?;
let thread_id = response
.thread_id
.as_deref()
.or_else(|| metadata.get("thread_id").and_then(|v| v.as_str()));
let (method, body) = self.build_send_body(channel_id, &response.content, thread_id);
self.proxy_send(team_id, &method, body)
.await
.map_err(|e| ChannelError::SendFailed {
name: channel_name,
reason: e.to_string(),
})?;
Ok(())
}
async fn send_status(
&self,
status: StatusUpdate,
metadata: &serde_json::Value,
) -> Result<(), ChannelError> {
let StatusUpdate::ApprovalNeeded {
request_id,
tool_name,
description,
parameters,
allow_always: _,
} = status
else {
return Ok(());
};
let event_type = metadata
.get("event_type")
.and_then(|v| v.as_str())
.unwrap_or("");
if event_type != "direct_message" {
tracing::warn!(
tool = %tool_name,
event_type,
"Approval requested in non-DM, skipping buttons"
);
return Ok(());
}
let channel_id = metadata
.get("channel_id")
.and_then(|v| v.as_str())
.ok_or_else(|| ChannelError::SendFailed {
name: self.name().to_string(),
reason: "Missing channel_id for approval buttons".into(),
})?;
let thread_id = metadata.get("thread_id").and_then(|v| v.as_str());
let team_id = metadata
.get("team_id")
.and_then(|v| v.as_str())
.unwrap_or(&self.team_id);
let approval_token = self
.client
.create_approval(team_id, channel_id, thread_id, &request_id)
.await
.map_err(|e| ChannelError::SendFailed {
name: self.name().to_string(),
reason: format!("Failed to register approval: {e}"),
})?;
let value_payload = serde_json::json!({
"approval_token": approval_token,
});
let value_str = value_payload.to_string();
let params_display =
serde_json::to_string_pretty(¶meters).unwrap_or_else(|_| parameters.to_string());
let blocks = serde_json::json!([
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": format!(
"*Tool approval required*\n`{tool_name}`: {description}\n```{params_display}```"
)
}
},
{
"type": "actions",
"elements": [
{
"type": "button",
"text": { "type": "plain_text", "text": "Approve" },
"style": "primary",
"action_id": "approve_tool",
"value": value_str,
},
{
"type": "button",
"text": { "type": "plain_text", "text": "Deny" },
"style": "danger",
"action_id": "deny_tool",
"value": value_str,
}
]
}
]);
let mut body = serde_json::json!({
"channel": channel_id,
"text": format!("Tool approval required: {tool_name} - {description}"),
"blocks": blocks,
});
if let Some(tid) = thread_id {
body["thread_ts"] = serde_json::Value::String(tid.to_string());
}
self.proxy_send(team_id, "chat.postMessage", body)
.await
.map_err(|e| ChannelError::SendFailed {
name: self.name().to_string(),
reason: e.to_string(),
})?;
Ok(())
}
async fn broadcast(
&self,
target: &str,
response: OutgoingResponse,
) -> Result<(), ChannelError> {
let channel_name = self.name().to_string();
let thread_id = response
.thread_id
.as_deref()
.or_else(|| response.metadata.get("thread_ts").and_then(|v| v.as_str()));
let (method, body) = self.build_send_body(target, &response.content, thread_id);
self.proxy_send(&self.team_id, &method, body)
.await
.map_err(|e| ChannelError::SendFailed {
name: channel_name,
reason: e.to_string(),
})?;
Ok(())
}
async fn health_check(&self) -> Result<(), ChannelError> {
self.client
.list_connections(&self.instance_id)
.await
.map_err(|_| ChannelError::HealthCheckFailed {
name: self.name().to_string(),
})?;
Ok(())
}
fn conversation_context(&self, metadata: &serde_json::Value) -> HashMap<String, String> {
let mut ctx = HashMap::new();
if let Some(sender) = metadata.get("sender_name").and_then(|v| v.as_str()) {
ctx.insert("sender".to_string(), sender.to_string());
}
if let Some(sender_id) = metadata.get("sender_id").and_then(|v| v.as_str()) {
ctx.insert("sender_uuid".to_string(), sender_id.to_string());
}
if let Some(channel_id) = metadata.get("channel_id").and_then(|v| v.as_str()) {
ctx.insert("group".to_string(), channel_id.to_string());
}
ctx.insert("platform".to_string(), self.provider.as_str().to_string());
ctx
}
async fn shutdown(&self) -> Result<(), ChannelError> {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    fn test_client() -> RelayClient {
        RelayClient::new(
            "http://localhost:3001".into(),
            secrecy::SecretString::from("key".to_string()),
            30,
        )
        .expect("client")
    }

    fn make_channel() -> RelayChannel {
        let (tx, rx) = mpsc::channel(64);
        RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx, rx)
    }

    /// Builds a well-formed Slack event. Tests override fields via struct update.
    fn slack_event(id: &str, event_type: &str, content: Option<&str>) -> ChannelEvent {
        ChannelEvent {
            id: id.into(),
            event_type: event_type.into(),
            provider: "slack".into(),
            provider_scope: "T123".into(),
            channel_id: "C456".into(),
            sender_id: "U789".into(),
            sender_name: None,
            content: content.map(Into::into),
            thread_id: None,
            raw: serde_json::Value::Null,
            timestamp: None,
        }
    }

    /// Waits up to one second for the next message on `stream`.
    async fn next_message(stream: &mut MessageStream) -> IncomingMessage {
        tokio::time::timeout(std::time::Duration::from_secs(1), stream.next())
            .await
            .expect("timed out waiting for a message")
            .expect("message stream ended unexpectedly")
    }

    /// A fixed approval request shared by the `send_status` tests.
    fn approval_request() -> StatusUpdate {
        StatusUpdate::ApprovalNeeded {
            request_id: "req1".into(),
            tool_name: "shell".into(),
            description: "run command".into(),
            parameters: serde_json::json!({}),
            allow_always: true,
        }
    }

    #[test]
    fn relay_channel_name() {
        let channel = make_channel();
        assert_eq!(channel.name(), DEFAULT_RELAY_NAME);
    }

    #[test]
    fn conversation_context_extracts_metadata() {
        let channel = make_channel();
        let metadata = serde_json::json!({
            "sender_name": "bob",
            "sender_id": "U123",
            "channel_id": "C456",
        });
        let ctx = channel.conversation_context(&metadata);
        assert_eq!(ctx.get("sender"), Some(&"bob".to_string()));
        assert_eq!(ctx.get("sender_uuid"), Some(&"U123".to_string()));
        assert_eq!(ctx.get("group"), Some(&"C456".to_string()));
        assert_eq!(ctx.get("platform"), Some(&"slack".to_string()));
    }

    #[test]
    fn build_send_body_slack() {
        let channel = make_channel();
        let (method, body) = channel.build_send_body("C456", "hello", Some("1234567.890"));
        assert_eq!(method, "chat.postMessage");
        assert_eq!(body["channel"], "C456");
        assert_eq!(body["text"], "hello");
        assert_eq!(body["thread_ts"], "1234567.890");
    }

    #[test]
    fn build_send_body_without_thread_omits_thread_ts() {
        let channel = make_channel();
        let (_, body) = channel.build_send_body("C456", "hello", None);
        assert!(body.get("thread_ts").is_none());
    }

    #[tokio::test]
    async fn start_processes_events() {
        let (tx, rx) = mpsc::channel(64);
        let channel =
            RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx.clone(), rx);
        let mut stream = channel.start().await.unwrap();

        tx.send(ChannelEvent {
            sender_name: Some("alice".into()),
            ..slack_event("1", "message", Some("hello"))
        })
        .await
        .unwrap();

        let msg = next_message(&mut stream).await;
        assert_eq!(msg.content, "hello");
        assert_eq!(msg.user_id, "U789");
    }

    #[tokio::test]
    async fn start_populates_message_metadata() {
        let (tx, rx) = mpsc::channel(64);
        let channel =
            RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx.clone(), rx);
        let mut stream = channel.start().await.unwrap();

        tx.send(ChannelEvent {
            sender_name: Some("alice".into()),
            thread_id: Some("111.222".into()),
            ..slack_event("1", "message", Some("hi"))
        })
        .await
        .unwrap();

        let msg = next_message(&mut stream).await;
        // Every key that `respond`/`send_status` read back must be present.
        let meta = &msg.metadata;
        let field = |key: &str| meta.get(key).and_then(|v| v.as_str()).map(str::to_string);
        assert_eq!(field("channel_id").as_deref(), Some("C456"));
        assert_eq!(field("sender_id").as_deref(), Some("U789"));
        assert_eq!(field("event_type").as_deref(), Some("message"));
        assert_eq!(field("provider").as_deref(), Some("slack"));
        assert_eq!(field("thread_id").as_deref(), Some("111.222"));
        assert!(field("sender_name").is_some(), "sender_name must be set");
    }

    #[tokio::test]
    async fn start_skips_non_message_events() {
        let (tx, rx) = mpsc::channel(64);
        let channel =
            RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx.clone(), rx);
        let mut stream = channel.start().await.unwrap();

        tx.send(slack_event("1", "reaction", None)).await.unwrap();
        tx.send(slack_event("2", "message", Some("real message")))
            .await
            .unwrap();

        let msg = next_message(&mut stream).await;
        assert_eq!(msg.content, "real message");
    }

    #[tokio::test]
    async fn start_skips_events_missing_required_fields() {
        let (tx, rx) = mpsc::channel(64);
        let channel =
            RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx.clone(), rx);
        let mut stream = channel.start().await.unwrap();

        tx.send(ChannelEvent {
            sender_id: "".into(),
            ..slack_event("1", "message", Some("no sender"))
        })
        .await
        .unwrap();
        tx.send(ChannelEvent {
            channel_id: "".into(),
            ..slack_event("2", "message", Some("no channel"))
        })
        .await
        .unwrap();
        tx.send(ChannelEvent {
            provider_scope: "".into(),
            ..slack_event("3", "message", Some("no scope"))
        })
        .await
        .unwrap();
        tx.send(slack_event("4", "message", Some("complete")))
            .await
            .unwrap();

        let msg = next_message(&mut stream).await;
        assert_eq!(msg.content, "complete");
    }

    #[tokio::test]
    async fn start_twice_errors() {
        let channel = make_channel();
        let _stream = channel.start().await.expect("first start succeeds");
        assert!(channel.start().await.is_err(), "second start must fail");
    }

    #[tokio::test]
    async fn test_send_status_non_approval_is_noop() {
        let channel = make_channel();
        let metadata = serde_json::json!({});
        let result = channel
            .send_status(
                StatusUpdate::ToolStarted {
                    name: "echo".into(),
                },
                &metadata,
            )
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_send_status_approval_non_dm_skips() {
        let channel = make_channel();
        let metadata = serde_json::json!({
            "event_type": "message",
            "channel_id": "C456",
            "sender_id": "U789",
        });
        let result = channel.send_status(approval_request(), &metadata).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_send_status_approval_dm_missing_channel_id_errors() {
        let channel = make_channel();
        let metadata = serde_json::json!({
            "event_type": "direct_message",
            "sender_id": "U789",
        });
        let result = channel.send_status(approval_request(), &metadata).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("channel_id"),
            "expected channel_id error, got: {err}"
        );
    }

    #[tokio::test]
    async fn test_send_status_approval_dm_without_sender_id_reaches_relay() {
        let channel = make_channel();
        let metadata = serde_json::json!({
            "event_type": "direct_message",
            "channel_id": "C456",
        });
        let result = channel.send_status(approval_request(), &metadata).await;
        // Validation passes without `sender_id`. The call then fails only
        // because no relay is reachable at the test client's URL (assumes
        // nothing is listening on localhost:3001).
        let err = result.expect_err("no relay is running in tests").to_string();
        assert!(
            !err.contains("sender_id"),
            "sender_id should not be required anymore, got: {err}"
        );
        assert!(
            err.contains("Failed to register approval"),
            "expected failure at the relay step, got: {err}"
        );
    }
}