use async_trait::async_trait;
use futures::{FutureExt, SinkExt, StreamExt};
use serde::Deserialize;
use serde_json::{json, Value};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::watch;
use tokio_tungstenite::client_async_tls;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tracing::{debug, error, info, warn};
use crate::bus::{InboundMessage, MediaAttachment, MediaType, MessageBus, OutboundMessage};
use crate::config::DiscordConfig;
use crate::error::{Result, ZeptoError};
use super::{BaseChannelConfig, Channel};
/// Base URL for Discord REST API v10 calls.
const DISCORD_API_BASE: &str = "https://discord.com/api/v10";
/// REST endpoint that returns the Gateway WebSocket URL.
const DISCORD_GATEWAY_URL: &str = "https://discord.com/api/v10/gateway";
/// Upper bound on the reconnect backoff delay, in seconds.
const MAX_RECONNECT_DELAY_SECS: u64 = 120;
/// Initial reconnect backoff delay, in seconds; doubled for every consecutive failure.
const BASE_RECONNECT_DELAY_SECS: u64 = 2;
/// Cap on the reconnect attempt counter (the backoff exponent). Despite the name,
/// the gateway loop never gives up after this many attempts; the counter simply
/// stops growing.
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
/// Gateway intents bitmask sent in IDENTIFY. Bits 0, 9, 12 and 15 correspond to
/// GUILDS, GUILD_MESSAGES, DIRECT_MESSAGES and MESSAGE_CONTENT in Discord's
/// intents table (NOTE(review): confirm against the Discord Gateway docs).
const GATEWAY_INTENTS: u64 = (1 << 0) | (1 << 9) | (1 << 12) | (1 << 15);
/// Maximum outbound message length, counted in `char`s; longer content is truncated.
const DISCORD_MAX_MESSAGE_LENGTH: usize = 2000;
/// Channel type id of guild forum channels; threads created there get a starter message.
const DISCORD_CHANNEL_TYPE_GUILD_FORUM: u8 = 15;
/// Channel type id of guild media channels, treated like forums for thread creation.
const DISCORD_CHANNEL_TYPE_GUILD_MEDIA: u8 = 16;
/// Maximum number of bytes of HTTP CONNECT proxy response headers that will be buffered.
const MAX_PROXY_CONNECT_RESPONSE_BYTES: usize = 8 * 1024;
/// Envelope shared by every Discord Gateway frame.
///
/// Only `op` is required; `d`, `s` and `t` deserialize to `None` when absent.
#[derive(Debug, Deserialize)]
struct GatewayPayload {
/// Gateway opcode (e.g. 0 = event dispatch, 10 = HELLO).
op: u8,
/// Opcode-specific payload data.
#[serde(default)]
d: Option<Value>,
/// Sequence number; the latest one seen is echoed back in heartbeats.
#[serde(default)]
s: Option<u64>,
/// Event name for dispatch frames (e.g. `MESSAGE_CREATE`, `READY`).
#[serde(default)]
t: Option<String>,
}
/// `d` payload of the HELLO (op 10) frame.
#[derive(Debug, Deserialize)]
struct HelloData {
/// Interval between heartbeats, in milliseconds.
heartbeat_interval: u64,
}
/// Attachment metadata carried by a `MESSAGE_CREATE` event.
#[derive(Debug, Deserialize)]
struct DiscordAttachment {
/// URL the attachment bytes are downloaded from.
url: String,
/// MIME type reported by Discord; only `image/*` attachments are downloaded.
#[serde(default)]
content_type: Option<String>,
/// Original file name, forwarded to the media attachment when present.
#[serde(default)]
filename: Option<String>,
/// Reported size; compared against a `20 * 1024 * 1024` limit before downloading.
#[serde(default)]
size: Option<u64>,
}
/// Subset of the `MESSAGE_CREATE` dispatch payload used by this channel.
#[derive(Debug, Deserialize)]
struct MessageCreateData {
/// Message text; defaults to empty when absent.
#[serde(default)]
content: String,
/// Channel the message was posted in; used as the inbound chat id.
channel_id: String,
/// Message author.
author: MessageAuthor,
/// Message id, forwarded as `discord_message_id` metadata.
id: String,
/// Attached files; defaults to empty when absent.
#[serde(default)]
attachments: Vec<DiscordAttachment>,
}
/// Author of a Discord message.
#[derive(Debug, Deserialize)]
struct MessageAuthor {
/// User id, matched against the allowlist.
id: String,
/// Bot flag; a missing value is treated as a human author.
#[serde(default)]
bot: Option<bool>,
}
/// Response body of the REST `GET /gateway` endpoint.
#[derive(Debug, Deserialize)]
struct GatewayResponse {
/// Base WebSocket URL; the version/encoding query is appended by the caller.
url: String,
}
/// Minimal view of a Discord channel object, used to detect forum/media channels.
#[derive(Debug, Deserialize)]
struct DiscordChannelInfo {
/// Numeric channel type (the JSON field is named `type`).
#[serde(rename = "type")]
channel_type: u8,
}
/// Thread-creation request parsed from outbound message metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
struct DiscordThreadRequest {
/// Thread name (required).
name: String,
/// When set, the thread is started from this existing message.
message_id: Option<String>,
/// Sent to Discord as `auto_archive_duration`.
auto_archive_minutes: Option<u16>,
}
/// Discord channel backed by the Gateway WebSocket (inbound) and REST API (outbound).
pub struct DiscordChannel {
/// Discord-specific configuration (token, allowlist, enable flag).
config: DiscordConfig,
/// Shared channel settings used by `is_allowed`.
base_config: BaseChannelConfig,
/// Bus that receives inbound messages.
bus: Arc<MessageBus>,
/// Running flag, shared with the background gateway task.
running: Arc<AtomicBool>,
/// Sender used by `stop()` to signal the gateway task; `None` until `start()`.
shutdown_tx: Option<watch::Sender<bool>>,
/// HTTP client reused for REST calls and attachment downloads.
http_client: reqwest::Client,
}
impl DiscordChannel {
pub fn new(config: DiscordConfig, bus: Arc<MessageBus>) -> Self {
let base_config = BaseChannelConfig {
name: "discord".to_string(),
allowlist: config.allow_from.clone(),
deny_by_default: config.deny_by_default,
};
Self {
config,
base_config,
bus,
running: Arc::new(AtomicBool::new(false)),
shutdown_tx: None,
http_client: reqwest::Client::new(),
}
}
pub fn discord_config(&self) -> &DiscordConfig {
&self.config
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
/// Asks the Discord REST API for the Gateway WebSocket URL and appends the
/// protocol query (`v=10`, JSON encoding).
///
/// # Errors
/// Fails on transport errors, non-2xx responses (the body is included in the
/// message), malformed JSON, or an empty/blank `url` field.
async fn fetch_gateway_url(client: &reqwest::Client, token: &str) -> Result<String> {
    let response = client
        .get(DISCORD_GATEWAY_URL)
        .header("Authorization", format!("Bot {}", token))
        .send()
        .await
        .map_err(|e| {
            ZeptoError::Channel(format!("Failed to fetch Discord Gateway URL: {}", e))
        })?;
    let status = response.status();
    let body = response.text().await.map_err(|e| {
        ZeptoError::Channel(format!("Failed to read Discord Gateway response: {}", e))
    })?;
    if !status.is_success() {
        return Err(ZeptoError::Channel(format!(
            "Discord Gateway HTTP {}: {}",
            status, body
        )));
    }
    let parsed: GatewayResponse = serde_json::from_str(&body).map_err(|e| {
        ZeptoError::Channel(format!("Invalid Discord Gateway response JSON: {}", e))
    })?;
    // Strip whitespace and any trailing slash so the suffix below always yields
    // `<base>/?v=10...` instead of `<base>//?v=10...`.
    let base = parsed.url.trim().trim_end_matches('/');
    if base.is_empty() {
        return Err(ZeptoError::Channel(
            "Discord Gateway response missing URL".to_string(),
        ));
    }
    Ok(format!("{}/?v=10&encoding=json", base))
}
/// Renders `proxy_url` with any username/password removed, for safe logging.
///
/// Unparseable input is replaced by a fixed placeholder rather than echoed.
fn sanitize_proxy_url(proxy_url: &str) -> String {
    let Ok(mut redacted) = reqwest::Url::parse(proxy_url) else {
        return "<invalid proxy URL>".to_string();
    };
    // Credentials must never reach logs or error messages.
    let _ = redacted.set_password(None);
    let _ = redacted.set_username("");
    redacted.to_string()
}
/// Parses an `http://` proxy URL into `(host, port, auth_header)`.
///
/// `host` is suitable for `TcpStream::connect((host, port))`: IPv6 literals are
/// returned without URL brackets. `auth_header`, when the URL carries a
/// username, is a complete `Proxy-Authorization: Basic ...\r\n` header line
/// built from the percent-decoded credentials.
///
/// # Errors
/// Returns `ZeptoError::Config` when the URL does not parse, is not `http`, or
/// lacks a host/port.
fn parse_proxy_url(proxy_url: &str) -> Result<(String, u16, Option<String>)> {
    let parsed = reqwest::Url::parse(proxy_url).map_err(|e| {
        ZeptoError::Config(format!(
            "Invalid Discord gateway proxy URL '{}': {}",
            Self::sanitize_proxy_url(proxy_url),
            e
        ))
    })?;
    if parsed.scheme() != "http" {
        return Err(ZeptoError::Config(format!(
            "Discord gateway proxy URL must use http:// scheme, got '{}'",
            parsed.scheme()
        )));
    }
    let raw_host = parsed.host_str().ok_or_else(|| {
        ZeptoError::Config("Discord gateway proxy URL missing host".to_string())
    })?;
    // `host_str()` renders IPv6 literals in URL form (`[::1]`); the bracketed form
    // is not resolvable as a socket host, so hand back the bare literal.
    let host = raw_host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(raw_host);
    let port = parsed.port_or_known_default().ok_or_else(|| {
        ZeptoError::Config("Discord gateway proxy URL missing port".to_string())
    })?;
    let auth_header = if parsed.username().is_empty() {
        None
    } else {
        use base64::Engine as _;
        // URL userinfo is percent-encoded; the Basic credentials need the raw text.
        let username = Self::percent_decode(parsed.username());
        let password = Self::percent_decode(parsed.password().unwrap_or(""));
        let creds = format!("{}:{}", username, password);
        let encoded = base64::engine::general_purpose::STANDARD.encode(creds.as_bytes());
        Some(format!("Proxy-Authorization: Basic {}\r\n", encoded))
    };
    Ok((host.to_string(), port, auth_header))
}
/// Picks an `http://` proxy URL for the gateway from the standard proxy env vars.
///
/// For `wss`/`https` gateway URLs the `HTTPS_PROXY` variants are consulted first,
/// otherwise the `HTTP_PROXY` variants. Blank values and proxies whose scheme is
/// not `http` are skipped. Returns `None` if `ws_url` does not parse or no
/// usable proxy is configured.
fn gateway_proxy_from_env(ws_url: &str) -> Option<String> {
    const SECURE_FIRST: [&str; 4] = ["HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy"];
    const PLAIN_FIRST: [&str; 4] = ["HTTP_PROXY", "http_proxy", "HTTPS_PROXY", "https_proxy"];

    let gateway = reqwest::Url::parse(ws_url).ok()?;
    let lookup_order = match gateway.scheme() {
        "wss" | "https" => SECURE_FIRST,
        _ => PLAIN_FIRST,
    };
    lookup_order.iter().find_map(|&var_name| {
        let value = std::env::var(var_name).ok()?;
        let candidate = value.trim();
        let usable = !candidate.is_empty()
            && reqwest::Url::parse(candidate).is_ok_and(|u| u.scheme() == "http");
        usable.then(|| candidate.to_string())
    })
}
/// Builds a `Proxy-Authorization: Bearer ...\r\n` header line from `NONO_PROXY_TOKEN`.
///
/// Returns `None` when the variable is unset, blank after trimming, or contains
/// CR/LF (which would allow header injection into the CONNECT request).
fn nono_proxy_auth_header() -> Option<String> {
    let raw = std::env::var("NONO_PROXY_TOKEN").ok()?;
    let token = raw.trim();
    if token.is_empty() {
        return None;
    }
    let has_line_break = token.chars().any(|c| matches!(c, '\r' | '\n'));
    if has_line_break {
        warn!(
            "NONO_PROXY_TOKEN contains CR or LF characters; \
             ignoring to prevent HTTP header injection"
        );
        return None;
    }
    Some(format!("Proxy-Authorization: Bearer {}\r\n", token))
}
/// Builds the raw HTTP/1.1 `CONNECT` request for tunnelling to `target_host:target_port`.
///
/// IPv6 literals are written in bracketed authority form (`[::1]:443`). Hosts
/// that already carry brackets (as produced by `Url::host_str()`) are used as-is
/// rather than being wrapped a second time. `auth_header`, if given, must be a
/// complete header line ending in `\r\n`; it is inserted verbatim.
fn build_connect_request(
    target_host: &str,
    target_port: u16,
    auth_header: Option<&str>,
) -> String {
    let already_bracketed = target_host.starts_with('[') && target_host.ends_with(']');
    let authority = if target_host.contains(':') && !already_bracketed {
        format!("[{}]:{}", target_host, target_port)
    } else {
        format!("{}:{}", target_host, target_port)
    };
    let mut request = format!(
        "CONNECT {authority} HTTP/1.1\r\nHost: {authority}\r\nProxy-Connection: Keep-Alive\r\n"
    );
    if let Some(auth) = auth_header {
        request.push_str(auth);
    }
    // Blank line terminates the request headers.
    request.push_str("\r\n");
    request
}
/// Decodes `%XX` escapes in `input` (hex digits in either case).
///
/// Malformed or truncated escapes are copied through unchanged. Decoded bytes
/// that are not valid UTF-8 become U+FFFD replacement characters.
fn percent_decode(input: &str) -> String {
    let raw = input.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(raw.len());
    let mut pos = 0;
    while let Some(&byte) = raw.get(pos) {
        if byte == b'%' {
            // Both hex digits must exist and be valid; otherwise emit '%' literally.
            if let (Some(&hi_byte), Some(&lo_byte)) = (raw.get(pos + 1), raw.get(pos + 2)) {
                let hi = (hi_byte as char).to_digit(16);
                let lo = (lo_byte as char).to_digit(16);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    decoded.push((hi * 16 + lo) as u8);
                    pos += 3;
                    continue;
                }
            }
        }
        decoded.push(byte);
        pos += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Returns whether an HTTP `CONNECT` response reports an established tunnel.
///
/// Only the status line is inspected. Headers, and any tunnel bytes that
/// arrived in the same read, do not affect the result. Per RFC 9110 §9.3.6,
/// any 2xx status on an `HTTP/1.0` or `HTTP/1.1` response counts as success.
///
/// # Errors
/// Returns `ZeptoError::Channel` when the status line is not valid UTF-8.
fn parse_connect_response_ok(response: &[u8]) -> Result<bool> {
    let status_end = response
        .iter()
        .position(|&byte| byte == b'\n')
        .unwrap_or(response.len());
    let status_line = std::str::from_utf8(&response[..status_end]).map_err(|e| {
        ZeptoError::Channel(format!(
            "HTTP CONNECT proxy response is not valid UTF-8: {}",
            e
        ))
    })?;
    // `split_whitespace` also drops the trailing '\r' of the status line.
    let mut parts = status_line.split_whitespace();
    let version = parts.next().unwrap_or_default();
    let code = parts.next().unwrap_or_default();
    let version_ok = matches!(version, "HTTP/1.1" | "HTTP/1.0");
    let is_2xx =
        code.len() == 3 && code.starts_with('2') && code.bytes().all(|b| b.is_ascii_digit());
    Ok(version_ok && is_2xx)
}
/// Opens the Discord Gateway WebSocket through an HTTP `CONNECT` proxy.
///
/// Steps:
/// 1. Connect to the proxy.
/// 2. Send `CONNECT host:port`, with a `Proxy-Authorization` header taken from the
///    proxy URL credentials or, failing that, from `NONO_PROXY_TOKEN`.
/// 3. Read the proxy's response headers.
/// 4. Run the TLS + WebSocket handshake over the established tunnel.
///
/// Each proxy I/O step is bounded by a 15 s timeout and the TLS/WebSocket
/// handshake by 30 s. Proxy response headers larger than
/// `MAX_PROXY_CONNECT_RESPONSE_BYTES` are rejected.
///
/// # Errors
/// Returns `ZeptoError::Config` for an unusable proxy URL and `ZeptoError::Channel`
/// for network, timeout, tunnel and handshake failures.
async fn connect_via_http_proxy(
ws_url: &str,
proxy_url: &str,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let ws_parsed = reqwest::Url::parse(ws_url).map_err(|e| {
ZeptoError::Channel(format!("Invalid Discord gateway URL '{}': {}", ws_url, e))
})?;
let target_host = ws_parsed
.host_str()
.ok_or_else(|| ZeptoError::Channel("Discord gateway URL missing host".to_string()))?;
let target_port = ws_parsed
.port_or_known_default()
.ok_or_else(|| ZeptoError::Channel("Discord gateway URL missing port".to_string()))?;
let (proxy_host, proxy_port, mut proxy_auth_header) = Self::parse_proxy_url(proxy_url)?;
// Credentials embedded in the proxy URL take precedence over NONO_PROXY_TOKEN.
if proxy_auth_header.is_none() {
proxy_auth_header = Self::nono_proxy_auth_header();
}
// Shared timeout for each individual proxy I/O step below.
let io_timeout = Duration::from_secs(15);
let mut stream = tokio::time::timeout(
io_timeout,
TcpStream::connect((proxy_host.as_str(), proxy_port)),
)
.await
.map_err(|_| ZeptoError::Channel("Timed out connecting to HTTP CONNECT proxy".to_string()))?
.map_err(|e| {
ZeptoError::Channel(format!(
"Failed to connect to Discord HTTP CONNECT proxy '{}:{}': {}",
proxy_host, proxy_port, e
))
})?;
let connect_request =
Self::build_connect_request(target_host, target_port, proxy_auth_header.as_deref());
tokio::time::timeout(io_timeout, stream.write_all(connect_request.as_bytes()))
.await
.map_err(|_| ZeptoError::Channel("Timed out writing HTTP CONNECT request".to_string()))?
.map_err(|e| {
ZeptoError::Channel(format!("Failed to write HTTP CONNECT request: {}", e))
})?;
tokio::time::timeout(io_timeout, stream.flush())
.await
.map_err(|_| {
ZeptoError::Channel("Timed out flushing HTTP CONNECT request".to_string())
})?
.map_err(|e| {
ZeptoError::Channel(format!("Failed to flush HTTP CONNECT request: {}", e))
})?;
// Accumulate the proxy response until the blank line ("\r\n\r\n") that ends
// its headers, refusing to buffer more than MAX_PROXY_CONNECT_RESPONSE_BYTES.
// NOTE(review): bytes received past the header terminator stay in `buf` and are
// not forwarded to the TLS handshake; this assumes the proxy sends nothing after
// its headers before the client speaks — confirm for exotic proxies.
let mut buf = Vec::with_capacity(1024);
let mut chunk = [0u8; 1024];
loop {
let n = tokio::time::timeout(io_timeout, stream.read(&mut chunk))
.await
.map_err(|_| {
ZeptoError::Channel("Timed out reading HTTP CONNECT proxy response".to_string())
})?
.map_err(|e| {
ZeptoError::Channel(format!(
"Failed reading HTTP CONNECT proxy response: {}",
e
))
})?;
if n == 0 {
return Err(ZeptoError::Channel(
"HTTP CONNECT proxy closed before sending response headers".to_string(),
));
}
if buf.len() + n > MAX_PROXY_CONNECT_RESPONSE_BYTES {
return Err(ZeptoError::Channel(format!(
"HTTP CONNECT proxy response too large (>{} bytes)",
MAX_PROXY_CONNECT_RESPONSE_BYTES
)));
}
buf.extend_from_slice(&chunk[..n]);
if buf.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
if !Self::parse_connect_response_ok(&buf)? {
// Surface only the status line; the rest of the response is not needed.
let preview = String::from_utf8_lossy(&buf);
let first_line = preview.lines().next().unwrap_or_default();
return Err(ZeptoError::Channel(format!(
"HTTP CONNECT proxy tunnel failed: {}",
first_line
)));
}
// The tunnel is now a raw byte pipe to the gateway host; run TLS (for wss)
// and the WebSocket handshake over it.
let request = ws_url.into_client_request().map_err(|e| {
ZeptoError::Channel(format!(
"Failed to create Discord WebSocket client request for '{}': {}",
ws_url, e
))
})?;
let tls_timeout = Duration::from_secs(30);
let (ws_stream, _) = tokio::time::timeout(tls_timeout, client_async_tls(request, stream))
.await
.map_err(|_| {
ZeptoError::Channel(
"Timed out during TLS/WebSocket handshake over proxy tunnel".to_string(),
)
})?
.map_err(|e| ZeptoError::Channel(format!("Discord WebSocket connect failed: {}", e)))?;
Ok(ws_stream)
}
async fn connect_gateway(ws_url: &str) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
if let Some(proxy_url) = Self::gateway_proxy_from_env(ws_url) {
info!("Discord gateway connecting via HTTP CONNECT proxy from env");
return Self::connect_via_http_proxy(ws_url, &proxy_url).await;
}
let (ws_stream, _) = connect_async(ws_url)
.await
.map_err(|e| ZeptoError::Channel(format!("Discord WebSocket connect failed: {}", e)))?;
Ok(ws_stream)
}
fn build_identify_payload(token: &str) -> String {
json!({
"op": 2,
"d": {
"token": token,
"intents": GATEWAY_INTENTS,
"properties": {
"os": std::env::consts::OS,
"browser": "zeptoclaw",
"device": "zeptoclaw"
}
}
})
.to_string()
}
fn build_heartbeat_payload(sequence: Option<u64>) -> String {
json!({
"op": 1,
"d": sequence
})
.to_string()
}
/// Reads `heartbeat_interval` (milliseconds) from a HELLO frame's `d` object.
///
/// # Errors
/// Returns `ZeptoError::Channel` when the field is missing or not an integer.
fn extract_heartbeat_interval(data: &Value) -> Result<u64> {
    serde_json::from_value::<HelloData>(data.clone())
        .map(|hello| hello.heartbeat_interval)
        .map_err(|e| ZeptoError::Channel(format!("Invalid Discord HELLO payload: {}", e)))
}
/// Converts a `MESSAGE_CREATE` payload into an inbound bus message.
///
/// Returns `None` in any of these cases:
/// - the payload does not parse;
/// - the author is a bot;
/// - the trimmed content, author id or channel id is empty;
/// - the author is not permitted.
///
/// An empty `allowlist` permits everyone unless `deny_by_default` is set. The
/// Discord message id is attached as `discord_message_id` metadata.
fn parse_message_create(
    data: &Value,
    allowlist: &[String],
    deny_by_default: bool,
) -> Option<InboundMessage> {
    let event: MessageCreateData = serde_json::from_value(data.clone()).ok()?;
    if event.author.bot == Some(true) {
        return None;
    }
    let text = event.content.trim().to_string();
    if text.is_empty() {
        return None;
    }
    let author_id = event.author.id.trim().to_string();
    if author_id.is_empty() {
        return None;
    }
    let permitted = if allowlist.is_empty() {
        !deny_by_default
    } else {
        allowlist.iter().any(|entry| *entry == author_id)
    };
    if !permitted {
        info!(
            "Discord: user {} not in allowlist, ignoring message",
            author_id
        );
        return None;
    }
    let chat_id = event.channel_id.trim().to_string();
    if chat_id.is_empty() {
        return None;
    }
    Some(
        InboundMessage::new("discord", &author_id, &chat_id, &text)
            .with_metadata("discord_message_id", &event.id),
    )
}
/// Exponential reconnect delay: `BASE_RECONNECT_DELAY_SECS * 2^attempt`, capped at
/// `MAX_RECONNECT_DELAY_SECS`. Saturating arithmetic keeps large attempts at the cap.
fn backoff_delay(attempt: u32) -> Duration {
    let growth = 2u64.saturating_pow(attempt);
    let uncapped = BASE_RECONNECT_DELAY_SECS.saturating_mul(growth);
    Duration::from_secs(std::cmp::min(uncapped, MAX_RECONNECT_DELAY_SECS))
}
/// Builds the JSON body for `POST /channels/{id}/messages`.
///
/// Content longer than `DISCORD_MAX_MESSAGE_LENGTH` characters is shortened with
/// `crate::utils::string::preview`, reserving three characters; the truncated
/// text is expected to end in `...`. A `reply_to` id becomes a
/// `message_reference`.
///
/// # Errors
/// Returns `ZeptoError::Channel` when `msg.chat_id` is blank.
fn build_send_payload(msg: &OutboundMessage) -> Result<Value> {
    if msg.chat_id.trim().is_empty() {
        return Err(ZeptoError::Channel(
            "Discord channel ID cannot be empty".to_string(),
        ));
    }
    let content = if msg.content.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH {
        msg.content.to_string()
    } else {
        crate::utils::string::preview(
            &msg.content,
            DISCORD_MAX_MESSAGE_LENGTH.saturating_sub(3),
        )
    };
    let mut payload = json!({ "content": content });
    if let (Some(reply_id), Some(fields)) = (msg.reply_to.as_ref(), payload.as_object_mut()) {
        fields.insert(
            "message_reference".to_string(),
            json!({ "message_id": reply_id }),
        );
    }
    Ok(payload)
}
/// Extracts a thread-creation request from `msg.metadata`, if one is present.
///
/// Recognised keys are `discord_thread_name`, `discord_thread_message_id` and
/// `discord_thread_auto_archive_minutes` (parsed as `u16`). Values are trimmed,
/// and blank values count as absent. Returns `Ok(None)` when none of the keys
/// are present.
///
/// # Errors
/// Fails when the auto-archive value is not a valid `u16`, or when any thread
/// key is present but `discord_thread_name` is missing.
fn parse_discord_thread_request(msg: &OutboundMessage) -> Result<Option<DiscordThreadRequest>> {
    let field = |key: &str| {
        msg.metadata
            .get(key)
            .map(|value| value.trim())
            .filter(|value| !value.is_empty())
    };
    let thread_name = field("discord_thread_name").map(str::to_string);
    let message_id = field("discord_thread_message_id").map(str::to_string);
    let auto_archive_minutes = field("discord_thread_auto_archive_minutes")
        .map(|raw| {
            raw.parse::<u16>().map_err(|_| {
                ZeptoError::Channel(format!(
                    "Invalid discord_thread_auto_archive_minutes '{}'",
                    raw
                ))
            })
        })
        .transpose()?;
    let requested =
        thread_name.is_some() || message_id.is_some() || auto_archive_minutes.is_some();
    if !requested {
        return Ok(None);
    }
    let Some(name) = thread_name else {
        return Err(ZeptoError::Channel(
            "Missing discord_thread_name for Discord thread create".to_string(),
        ));
    };
    Ok(Some(DiscordThreadRequest {
        name,
        message_id,
        auto_archive_minutes,
    }))
}
fn build_create_thread_payload(
req: &DiscordThreadRequest,
content: &str,
forum_like: bool,
) -> Value {
let mut payload = json!({ "name": req.name });
if let Some(minutes) = req.auto_archive_minutes {
if let Some(map) = payload.as_object_mut() {
map.insert("auto_archive_duration".to_string(), json!(minutes));
}
}
if forum_like {
let starter = content.trim();
let starter_content = if starter.is_empty() {
&req.name
} else {
starter
};
if let Some(map) = payload.as_object_mut() {
map.insert("message".to_string(), json!({ "content": starter_content }));
}
}
payload
}
/// Looks up the numeric type of `channel_id` via the REST API.
///
/// Best-effort: a non-2xx response or an unparseable body is logged and yields
/// `Ok(None)`.
///
/// # Errors
/// Returns `ZeptoError::Channel` only when the request cannot be sent or its
/// body cannot be read.
async fn fetch_channel_type(&self, token: &str, channel_id: &str) -> Result<Option<u8>> {
    let endpoint = format!("{}/channels/{}", DISCORD_API_BASE, channel_id);
    let auth = format!("Bot {}", token);
    let response = self
        .http_client
        .get(&endpoint)
        .header("Authorization", auth)
        .send()
        .await
        .map_err(|e| {
            ZeptoError::Channel(format!(
                "Failed to fetch Discord channel metadata for thread create: {}",
                e
            ))
        })?;
    let status = response.status();
    let body = response.text().await.map_err(|e| {
        ZeptoError::Channel(format!("Failed to read Discord channel response: {}", e))
    })?;
    if !status.is_success() {
        warn!(
            "Discord channel lookup returned HTTP {} while creating thread: {}",
            status, body
        );
        return Ok(None);
    }
    let channel_type = match serde_json::from_str::<DiscordChannelInfo>(&body) {
        Ok(info) => Some(info.channel_type),
        Err(e) => {
            warn!("Discord channel lookup JSON parse failed: {}", e);
            None
        }
    };
    Ok(channel_type)
}
async fn create_thread(
&self,
token: &str,
channel_id: &str,
req: &DiscordThreadRequest,
content: &str,
) -> Result<()> {
let is_forum_like = if req.message_id.is_none() {
matches!(
self.fetch_channel_type(token, channel_id).await?,
Some(channel_type)
if channel_type == DISCORD_CHANNEL_TYPE_GUILD_FORUM
|| channel_type == DISCORD_CHANNEL_TYPE_GUILD_MEDIA
)
} else {
false
};
let payload = Self::build_create_thread_payload(req, content, is_forum_like);
let url = if let Some(message_id) = req.message_id.as_deref() {
format!(
"{}/channels/{}/messages/{}/threads",
DISCORD_API_BASE, channel_id, message_id
)
} else {
format!("{}/channels/{}/threads", DISCORD_API_BASE, channel_id)
};
let response = self
.http_client
.post(&url)
.header("Authorization", format!("Bot {}", token))
.json(&payload)
.send()
.await
.map_err(|e| {
ZeptoError::Channel(format!("Failed to call Discord thread create API: {}", e))
})?;
let status = response.status();
let body = response.text().await.map_err(|e| {
ZeptoError::Channel(format!(
"Failed to read Discord thread create response: {}",
e
))
})?;
if !status.is_success() {
return Err(ZeptoError::Channel(format!(
"Discord thread create API returned HTTP {}: {}",
status, body
)));
}
Ok(())
}
async fn run_gateway_loop(
client: reqwest::Client,
token: String,
bus: Arc<MessageBus>,
allowlist: Vec<String>,
deny_by_default: bool,
mut shutdown_rx: watch::Receiver<bool>,
) {
let mut reconnect_attempt: u32 = 0;
loop {
if *shutdown_rx.borrow() {
info!("Discord gateway shutdown requested");
return;
}
let ws_url = tokio::select! {
_ = shutdown_rx.changed() => {
info!("Discord gateway shutdown requested");
return;
}
result = Self::fetch_gateway_url(&client, &token) => {
match result {
Ok(url) => url,
Err(e) => {
warn!("Discord: failed to fetch gateway URL: {}", e);
let delay = Self::backoff_delay(reconnect_attempt);
reconnect_attempt =
(reconnect_attempt + 1).min(MAX_RECONNECT_ATTEMPTS);
tokio::select! {
_ = shutdown_rx.changed() => return,
_ = tokio::time::sleep(delay) => continue,
}
}
}
}
};
info!("Discord gateway URL discovered: {}", ws_url);
let ws_stream = tokio::select! {
_ = shutdown_rx.changed() => {
info!("Discord gateway shutdown requested");
return;
}
result = Self::connect_gateway(&ws_url) => {
match result {
Ok(stream) => stream,
Err(e) => {
warn!("Discord: WebSocket connect failed: {}", e);
let delay = Self::backoff_delay(reconnect_attempt);
reconnect_attempt =
(reconnect_attempt + 1).min(MAX_RECONNECT_ATTEMPTS);
tokio::select! {
_ = shutdown_rx.changed() => return,
_ = tokio::time::sleep(delay) => continue,
}
}
}
}
};
info!("Discord gateway WebSocket connected");
reconnect_attempt = 0;
let (mut ws_writer, mut ws_reader) = ws_stream.split();
let heartbeat_interval = loop {
let next = tokio::select! {
_ = shutdown_rx.changed() => {
info!("Discord gateway shutdown requested");
return;
}
msg = ws_reader.next() => msg,
};
match next {
Some(Ok(WsMessage::Text(raw))) => {
match serde_json::from_str::<GatewayPayload>(&raw) {
Ok(payload) if payload.op == 10 => {
if let Some(ref data) = payload.d {
match Self::extract_heartbeat_interval(data) {
Ok(interval) => {
debug!(
"Discord HELLO: heartbeat_interval = {}ms",
interval
);
break interval;
}
Err(e) => {
warn!("Discord: invalid HELLO data: {}", e);
break 41250; }
}
} else {
warn!("Discord: HELLO without data, using default interval");
break 41250;
}
}
Ok(_) => {
debug!("Discord: ignoring pre-HELLO payload");
}
Err(e) => {
debug!("Discord: failed to parse pre-HELLO payload: {}", e);
}
}
}
Some(Ok(_)) => {}
Some(Err(e)) => {
warn!("Discord: WebSocket error waiting for HELLO: {}", e);
break 0; }
None => {
warn!("Discord: WebSocket closed before HELLO");
break 0;
}
}
};
if heartbeat_interval == 0 {
let delay = Self::backoff_delay(reconnect_attempt);
reconnect_attempt = (reconnect_attempt + 1).min(MAX_RECONNECT_ATTEMPTS);
tokio::select! {
_ = shutdown_rx.changed() => return,
_ = tokio::time::sleep(delay) => continue,
}
}
let identify = Self::build_identify_payload(&token);
if let Err(e) = ws_writer.send(WsMessage::Text(identify.into())).await {
warn!("Discord: failed to send IDENTIFY: {}", e);
let delay = Self::backoff_delay(reconnect_attempt);
reconnect_attempt = (reconnect_attempt + 1).min(MAX_RECONNECT_ATTEMPTS);
tokio::select! {
_ = shutdown_rx.changed() => return,
_ = tokio::time::sleep(delay) => continue,
}
}
let sequence = Arc::new(std::sync::atomic::AtomicU64::new(0));
let sequence_valid = Arc::new(AtomicBool::new(false));
let heartbeat_shutdown = shutdown_rx.clone();
let seq_clone = Arc::clone(&sequence);
let seq_valid_clone = Arc::clone(&sequence_valid);
let (heartbeat_tx, mut heartbeat_rx) = tokio::sync::mpsc::channel::<String>(16);
tokio::spawn({
let mut shutdown = heartbeat_shutdown;
async move {
let task_result = std::panic::AssertUnwindSafe(async move {
let interval = Duration::from_millis(heartbeat_interval);
loop {
tokio::select! {
_ = shutdown.changed() => {
debug!("Discord heartbeat task shutting down");
return;
}
_ = tokio::time::sleep(interval) => {
let s = if seq_valid_clone.load(Ordering::SeqCst) {
Some(seq_clone.load(Ordering::SeqCst))
} else {
None
};
let payload = Self::build_heartbeat_payload(s);
if heartbeat_tx.send(payload).await.is_err() {
debug!("Discord heartbeat channel closed");
return;
}
}
}
}
})
.catch_unwind()
.await;
if task_result.is_err() {
error!("Discord heartbeat task panicked");
}
}
});
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
info!("Discord gateway shutdown requested");
return;
}
hb = heartbeat_rx.recv() => {
match hb {
Some(payload) => {
if let Err(e) = ws_writer.send(WsMessage::Text(payload.into())).await {
warn!("Discord: heartbeat send failed: {}", e);
break;
}
}
None => {
debug!("Discord heartbeat channel closed");
break;
}
}
}
msg = ws_reader.next() => {
match msg {
Some(Ok(WsMessage::Text(raw))) => {
match serde_json::from_str::<GatewayPayload>(&raw) {
Ok(payload) => {
if let Some(s) = payload.s {
sequence.store(s, Ordering::SeqCst);
sequence_valid.store(true, Ordering::SeqCst);
}
match payload.op {
0 => {
if let Some(event_name) = payload.t.as_deref() {
if event_name == "MESSAGE_CREATE" {
if let Some(ref data) = payload.d {
if let Some(mut inbound) =
Self::parse_message_create(data, &allowlist, deny_by_default)
{
if let Ok(msg_data) = serde_json::from_value::<MessageCreateData>(data.clone()) {
for att in &msg_data.attachments {
if let Some(ref ct) = att.content_type {
if ct.starts_with("image/")
&& att.size.is_none_or(|s| s <= 20 * 1024 * 1024)
{
match client.get(&att.url).send().await {
Ok(resp) => {
if let Ok(bytes) = resp.bytes().await {
let mut media = MediaAttachment::new(MediaType::Image)
.with_data(bytes.to_vec())
.with_mime_type(ct);
if let Some(ref name) = att.filename {
media = media.with_filename(name);
}
inbound = inbound.with_media(media);
}
}
Err(e) => warn!("Failed to download Discord attachment: {}", e),
}
}
}
}
}
if let Err(e) =
bus.publish_inbound(inbound).await
{
error!(
"Failed to publish Discord inbound message: {}",
e
);
}
}
}
} else if event_name == "READY" {
info!("Discord gateway READY");
} else {
debug!("Discord: ignoring event {}", event_name);
}
}
}
1 => {
let s = if sequence_valid.load(Ordering::SeqCst) {
Some(sequence.load(Ordering::SeqCst))
} else {
None
};
let hb = Self::build_heartbeat_payload(s);
if let Err(e) = ws_writer.send(WsMessage::Text(hb.into())).await {
warn!("Discord: heartbeat response send failed: {}", e);
break;
}
}
7 => {
info!("Discord: server requested reconnect");
break;
}
9 => {
warn!("Discord: invalid session, reconnecting");
break;
}
11 => {
debug!("Discord: heartbeat ACK received");
}
_ => {
debug!("Discord: unhandled opcode {}", payload.op);
}
}
}
Err(e) => {
debug!("Discord: failed to parse gateway payload: {}", e);
}
}
}
Some(Ok(WsMessage::Ping(payload))) => {
if let Err(e) = ws_writer.send(WsMessage::Pong(payload)).await {
warn!("Discord: pong send failed: {}", e);
break;
}
}
Some(Ok(WsMessage::Close(frame))) => {
info!("Discord: WebSocket closed by server: {:?}", frame);
break;
}
Some(Ok(_)) => {}
Some(Err(e)) => {
warn!("Discord: WebSocket stream error: {}", e);
break;
}
None => {
warn!("Discord: WebSocket stream ended");
break;
}
}
}
}
}
let delay = Self::backoff_delay(reconnect_attempt);
reconnect_attempt = (reconnect_attempt + 1).min(MAX_RECONNECT_ATTEMPTS);
info!("Discord: reconnecting in {} seconds", delay.as_secs());
tokio::select! {
_ = shutdown_rx.changed() => return,
_ = tokio::time::sleep(delay) => {},
}
}
}
}
#[async_trait]
impl Channel for DiscordChannel {
    /// Channel identifier used for routing (`"discord"`).
    fn name(&self) -> &str {
        "discord"
    }

    /// Starts the Discord Gateway connection in a background task.
    ///
    /// Returns `Ok(())` without spawning anything when the channel is already
    /// running or is disabled in configuration.
    ///
    /// # Errors
    /// Returns `ZeptoError::Config` when the bot token is blank.
    async fn start(&mut self) -> Result<()> {
        if self.running.swap(true, Ordering::SeqCst) {
            info!("Discord channel already running");
            return Ok(());
        }
        if !self.config.enabled {
            warn!("Discord channel is disabled in configuration");
            self.running.store(false, Ordering::SeqCst);
            return Ok(());
        }
        let token = self.config.token.trim().to_string();
        if token.is_empty() {
            self.running.store(false, Ordering::SeqCst);
            return Err(ZeptoError::Config("Discord bot token is empty".to_string()));
        }
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        // Second receiver so the task can tell, once the loop has exited, whether
        // the exit was requested through `stop()`.
        let stop_signal = shutdown_rx.clone();
        self.shutdown_tx = Some(shutdown_tx);
        info!("Starting Discord channel with Gateway WebSocket");
        let running_clone = Arc::clone(&self.running);
        let http_client = self.http_client.clone();
        let bus = Arc::clone(&self.bus);
        let allow_from = self.config.allow_from.clone();
        let deny_by_default = self.config.deny_by_default;
        tokio::spawn(async move {
            let task_result = std::panic::AssertUnwindSafe(async move {
                Self::run_gateway_loop(
                    http_client,
                    token,
                    bus,
                    allow_from,
                    deny_by_default,
                    shutdown_rx,
                )
                .await;
            })
            .catch_unwind()
            .await;
            if task_result.is_err() {
                error!("Discord gateway task panicked");
            }
            // After an explicit stop(), `running` was already cleared by stop(),
            // and a later start() may have set it again for a new session.
            // Clearing it here would mark that new session as stopped, so only
            // clear it when the loop ended on its own (panic / dropped sender).
            let stop_requested = *stop_signal.borrow();
            if !stop_requested {
                running_clone.store(false, Ordering::SeqCst);
            }
        });
        Ok(())
    }

    /// Signals the gateway task to shut down. Calling it when already stopped is a no-op.
    async fn stop(&mut self) -> Result<()> {
        if !self.running.swap(false, Ordering::SeqCst) {
            info!("Discord channel already stopped");
            return Ok(());
        }
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(true);
        }
        info!("Discord channel stopped");
        Ok(())
    }

    /// Sends `msg` to the Discord channel identified by `msg.chat_id`.
    ///
    /// When the metadata requests a thread (`discord_thread_*` keys), a thread is
    /// created instead and no regular message is posted.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the channel is not running;
    /// - the token or channel id is blank;
    /// - the thread metadata is invalid;
    /// - the Discord API call fails or returns non-2xx.
    async fn send(&self, msg: OutboundMessage) -> Result<()> {
        if !self.running.load(Ordering::SeqCst) {
            return Err(ZeptoError::Channel(
                "Discord channel not running".to_string(),
            ));
        }
        let token = self.config.token.trim();
        if token.is_empty() {
            return Err(ZeptoError::Config("Discord bot token is empty".to_string()));
        }
        let channel_id = msg.chat_id.trim();
        if channel_id.is_empty() {
            return Err(ZeptoError::Channel(
                "Discord channel ID cannot be empty".to_string(),
            ));
        }
        if let Some(thread_req) = Self::parse_discord_thread_request(&msg)? {
            self.create_thread(token, channel_id, &thread_req, &msg.content)
                .await?;
            info!("Discord: thread created successfully");
            return Ok(());
        }
        let payload = Self::build_send_payload(&msg)?;
        let url = format!("{}/channels/{}/messages", DISCORD_API_BASE, channel_id);
        let response = self
            .http_client
            .post(&url)
            .header("Authorization", format!("Bot {}", token))
            .json(&payload)
            .send()
            .await
            .map_err(|e| ZeptoError::Channel(format!("Failed to call Discord API: {}", e)))?;
        let status = response.status();
        let body = response.text().await.map_err(|e| {
            ZeptoError::Channel(format!("Failed to read Discord API response: {}", e))
        })?;
        if !status.is_success() {
            return Err(ZeptoError::Channel(format!(
                "Discord API returned HTTP {}: {}",
                status, body
            )));
        }
        info!("Discord: message sent successfully");
        Ok(())
    }

    /// Whether the gateway task is considered running.
    fn is_running(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }

    /// Delegates the allowlist check to the shared base channel config.
    fn is_allowed(&self, user_id: &str) -> bool {
        self.base_config.is_allowed(user_id)
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn test_bus() -> Arc<MessageBus> {
Arc::new(MessageBus::new())
}
fn test_config() -> DiscordConfig {
DiscordConfig {
enabled: true,
token: "test-bot-token".to_string(),
allow_from: vec!["123456789".to_string()],
..Default::default()
}
}
#[test]
fn test_channel_name() {
let channel = DiscordChannel::new(test_config(), test_bus());
assert_eq!(channel.name(), "discord");
}
#[test]
fn test_config_initialization() {
let config = DiscordConfig {
enabled: true,
token: "my-token".to_string(),
allow_from: vec!["U1".to_string(), "U2".to_string()],
..Default::default()
};
let channel = DiscordChannel::new(config, test_bus());
assert!(channel.is_enabled());
assert_eq!(channel.discord_config().token, "my-token");
assert_eq!(channel.discord_config().allow_from.len(), 2);
assert!(!channel.is_running());
}
#[test]
fn test_is_allowed_delegation() {
let channel = DiscordChannel::new(test_config(), test_bus());
assert!(channel.is_allowed("123456789"));
assert!(!channel.is_allowed("999999999"));
}
#[test]
fn test_is_allowed_empty_allowlist() {
let config = DiscordConfig {
enabled: true,
token: "tok".to_string(),
allow_from: vec![],
..Default::default()
};
let channel = DiscordChannel::new(config, test_bus());
assert!(channel.is_allowed("anyone"));
assert!(channel.is_allowed("literally_anyone"));
}
#[test]
fn test_gateway_url_formatting() {
let base = "wss://gateway.discord.gg";
let formatted = format!("{}/?v=10&encoding=json", base);
assert_eq!(formatted, "wss://gateway.discord.gg/?v=10&encoding=json");
}
#[test]
fn test_gateway_response_deserialization() {
let json = r#"{"url": "wss://gateway.discord.gg"}"#;
let resp: GatewayResponse = serde_json::from_str(json).expect("should parse");
assert_eq!(resp.url, "wss://gateway.discord.gg");
}
#[test]
fn test_message_create_deserialization() {
let data = json!({
"id": "msg-001",
"content": "Hello from Discord!",
"channel_id": "ch-100",
"author": {
"id": "user-42",
"bot": false
}
});
let inbound = DiscordChannel::parse_message_create(&data, &[], false);
assert!(inbound.is_some());
let msg = inbound.unwrap();
assert_eq!(msg.channel, "discord");
assert_eq!(msg.sender_id, "user-42");
assert_eq!(msg.chat_id, "ch-100");
assert_eq!(msg.content, "Hello from Discord!");
assert_eq!(
msg.metadata.get("discord_message_id"),
Some(&"msg-001".to_string())
);
}
#[test]
fn test_message_create_with_allowlist() {
let data = json!({
"id": "msg-002",
"content": "test",
"channel_id": "ch-200",
"author": { "id": "allowed-user", "bot": false }
});
let allowed =
DiscordChannel::parse_message_create(&data, &["allowed-user".to_string()], false);
assert!(allowed.is_some());
let denied =
DiscordChannel::parse_message_create(&data, &["someone-else".to_string()], false);
assert!(denied.is_none());
}
#[test]
fn test_heartbeat_interval_extraction() {
let data = json!({ "heartbeat_interval": 41250 });
let interval = DiscordChannel::extract_heartbeat_interval(&data).expect("should extract");
assert_eq!(interval, 41250);
}
#[test]
fn test_heartbeat_interval_extraction_invalid() {
let data = json!({ "something_else": 123 });
let result = DiscordChannel::extract_heartbeat_interval(&data);
assert!(result.is_err());
}
#[test]
fn test_bot_message_ignored() {
let data = json!({
"id": "msg-003",
"content": "I am a bot",
"channel_id": "ch-300",
"author": { "id": "bot-user", "bot": true }
});
let result = DiscordChannel::parse_message_create(&data, &[], false);
assert!(result.is_none());
}
#[test]
fn test_empty_content_ignored() {
let data = json!({
"id": "msg-004",
"content": " ",
"channel_id": "ch-400",
"author": { "id": "user-1", "bot": false }
});
let result = DiscordChannel::parse_message_create(&data, &[], false);
assert!(result.is_none());
}
#[test]
fn test_missing_bot_field_treated_as_human() {
let data = json!({
"id": "msg-005",
"content": "No bot field",
"channel_id": "ch-500",
"author": { "id": "user-2" }
});
let result = DiscordChannel::parse_message_create(&data, &[], false);
assert!(result.is_some());
}
/// A plain outbound message serializes its text and carries no reply reference.
#[test]
fn test_outbound_message_payload() {
    let outbound = OutboundMessage::new("discord", "ch-100", "Hello back!");
    let body = DiscordChannel::build_send_payload(&outbound).expect("should build payload");
    let has_reference = body.get("message_reference").is_some();
    assert_eq!(body["content"], "Hello back!");
    assert!(!has_reference);
}
/// A reply target is emitted as `message_reference.message_id`.
#[test]
fn test_outbound_message_with_reply() {
    let outbound =
        OutboundMessage::new("discord", "ch-100", "reply text").with_reply("original-msg-id");
    let body = DiscordChannel::build_send_payload(&outbound).expect("should build payload");
    let reference = &body["message_reference"];
    assert_eq!(body["content"], "reply text");
    assert_eq!(reference["message_id"], "original-msg-id");
}
/// A blank channel id cannot be turned into a send payload.
#[test]
fn test_outbound_empty_channel_id() {
    let outbound = OutboundMessage::new("discord", " ", "test");
    assert!(DiscordChannel::build_send_payload(&outbound).is_err());
}
/// Content longer than Discord's limit is cut down and marked with an ellipsis.
#[test]
fn test_outbound_message_truncation() {
    let oversized = "x".repeat(2500);
    let outbound = OutboundMessage::new("discord", "ch-100", &oversized);
    let body = DiscordChannel::build_send_payload(&outbound).expect("should build payload");
    let sent = body["content"].as_str().unwrap();
    let within_limit = sent.len() <= DISCORD_MAX_MESSAGE_LENGTH;
    assert!(within_limit);
    assert!(sent.ends_with("..."));
}
/// Without thread metadata, no thread request is produced.
#[test]
fn test_parse_discord_thread_request_none() {
    let outbound = OutboundMessage::new("discord", "ch-100", "hello");
    let parsed =
        DiscordChannel::parse_discord_thread_request(&outbound).expect("parse should succeed");
    assert!(parsed.is_none());
}
/// Thread name and auto-archive minutes are read from outbound metadata.
#[test]
fn test_parse_discord_thread_request_success() {
    let outbound = OutboundMessage::new("discord", "ch-100", "hello")
        .with_metadata("discord_thread_name", "Ops")
        .with_metadata("discord_thread_auto_archive_minutes", "60");
    let DiscordThreadRequest {
        name,
        message_id,
        auto_archive_minutes,
    } = DiscordChannel::parse_discord_thread_request(&outbound)
        .expect("parse should succeed")
        .expect("thread request should exist");
    assert_eq!(name, "Ops");
    assert_eq!(auto_archive_minutes, Some(60));
    assert_eq!(message_id, None);
}
/// Supplying thread metadata without a thread name is rejected with a descriptive error.
#[test]
fn test_parse_discord_thread_request_requires_name() {
    let outbound = OutboundMessage::new("discord", "ch-100", "hello")
        .with_metadata("discord_thread_message_id", "m1");
    let message = DiscordChannel::parse_discord_thread_request(&outbound)
        .unwrap_err()
        .to_string();
    assert!(message.contains("Missing discord_thread_name for Discord thread create"));
}
/// For forum-like channels the payload includes a starter `message`; with blank
/// content the test expects the starter text to be the thread name.
#[test]
fn test_build_create_thread_payload_forum_like_uses_starter_message() {
    let request = DiscordThreadRequest {
        auto_archive_minutes: Some(1440),
        message_id: None,
        name: "Daily".to_string(),
    };
    let body = DiscordChannel::build_create_thread_payload(&request, " ", true);
    let starter = &body["message"];
    assert_eq!(body["name"], "Daily");
    assert_eq!(body["auto_archive_duration"], 1440);
    assert_eq!(starter["content"], "Daily");
}
/// A newly constructed channel is not running.
#[tokio::test]
async fn test_running_state_default() {
    let fresh = DiscordChannel::new(test_config(), test_bus());
    let running = fresh.is_running();
    assert!(!running);
}
/// Starting an enabled channel with an empty token fails and leaves it stopped.
#[tokio::test]
async fn test_start_without_token() {
    let config = DiscordConfig {
        token: String::new(),
        enabled: true,
        allow_from: Vec::new(),
        ..Default::default()
    };
    let mut channel = DiscordChannel::new(config, test_bus());
    let outcome = channel.start().await;
    assert!(outcome.is_err());
    assert!(!channel.is_running());
}
/// Starting a disabled channel is a successful no-op that does not mark it running.
#[tokio::test]
async fn test_start_disabled() {
    let config = DiscordConfig {
        token: "some-token".to_string(),
        allow_from: Vec::new(),
        enabled: false,
        ..Default::default()
    };
    let mut channel = DiscordChannel::new(config, test_bus());
    let outcome = channel.start().await;
    assert!(outcome.is_ok());
    assert!(!channel.is_running());
}
/// Stopping a channel that was never started succeeds.
#[tokio::test]
async fn test_stop_not_running() {
    let mut idle = DiscordChannel::new(test_config(), test_bus());
    assert!(idle.stop().await.is_ok());
}
/// Sending through a channel that has not been started is an error.
#[tokio::test]
async fn test_send_not_running() {
    let idle = DiscordChannel::new(test_config(), test_bus());
    let outcome = idle
        .send(OutboundMessage::new("discord", "ch-100", "Hello"))
        .await;
    assert!(outcome.is_err());
}
/// The reconnect delay doubles with each attempt, starting from a 2-second base.
#[test]
fn test_backoff_delay_increases_exponentially() {
    // Expected delay in seconds for attempts 0, 1, 2, 3.
    let expected_secs = [2, 4, 8, 16];
    for (attempt, secs) in (0u32..).zip(expected_secs) {
        assert_eq!(
            DiscordChannel::backoff_delay(attempt),
            Duration::from_secs(secs),
            "unexpected backoff delay for attempt {attempt}"
        );
    }
}
/// Large attempt numbers are clamped to the maximum reconnect delay.
#[test]
fn test_backoff_delay_caps_at_max() {
    let ceiling = Duration::from_secs(MAX_RECONNECT_DELAY_SECS);
    assert_eq!(DiscordChannel::backoff_delay(20), ceiling);
}
/// The largest possible attempt number still yields the capped delay instead of overflowing.
#[test]
fn test_backoff_delay_does_not_overflow() {
    let ceiling = Duration::from_secs(MAX_RECONNECT_DELAY_SECS);
    assert_eq!(DiscordChannel::backoff_delay(u32::MAX), ceiling);
}
/// An http:// proxy URL with userinfo yields host, port, and a Basic auth header line.
#[test]
fn test_parse_proxy_url_http_with_auth() {
    let (proxy_host, proxy_port, auth_header) =
        DiscordChannel::parse_proxy_url("http://alice:secret@127.0.0.1:8080")
            .expect("proxy should parse");
    assert_eq!(proxy_host, "127.0.0.1");
    assert_eq!(proxy_port, 8080);
    // `expect` also enforces that an auth header is present.
    let header = auth_header.expect("auth header expected");
    assert!(header.starts_with("Proxy-Authorization: Basic "));
}
/// Only plain http:// proxies are accepted for the CONNECT tunnel.
#[test]
fn test_parse_proxy_url_rejects_non_http_scheme() {
    let message = DiscordChannel::parse_proxy_url("https://127.0.0.1:8080")
        .expect_err("https proxy must be rejected for CONNECT transport")
        .to_string();
    assert!(message.contains("Discord gateway proxy URL must use http:// scheme"));
}
/// A 200 status line in the proxy's CONNECT reply is treated as success.
#[test]
fn test_parse_connect_response_ok() {
    const RESPONSE: &[u8] = b"HTTP/1.1 200 Connection Established\r\nProxy-Agent: test\r\n\r\n";
    let accepted =
        DiscordChannel::parse_connect_response_ok(RESPONSE).expect("parse should succeed");
    assert!(accepted);
}
/// A non-200 status (here 407) parses fine but is not treated as success.
#[test]
fn test_parse_connect_response_not_ok() {
    const RESPONSE: &[u8] = b"HTTP/1.1 407 Proxy Authentication Required\r\n\r\n";
    let accepted =
        DiscordChannel::parse_connect_response_ok(RESPONSE).expect("parse should succeed");
    assert!(!accepted);
}
/// A status code that merely starts with "200" must not be mistaken for 200.
#[test]
fn test_parse_connect_response_rejects_2000() {
    const RESPONSE: &[u8] = b"HTTP/1.1 2000 OK\r\n\r\n";
    let accepted =
        DiscordChannel::parse_connect_response_ok(RESPONSE).expect("parse should succeed");
    assert!(!accepted, "status code 2000 must not be accepted as a 200");
}
/// IPv6 literals are bracketed both in the CONNECT target and the Host header.
#[test]
fn test_build_connect_request_ipv6_brackets() {
    let request = DiscordChannel::build_connect_request("::1", 443, None);
    let request_line = request.lines().next().unwrap_or_default();
    assert!(
        request.starts_with("CONNECT [::1]:443 HTTP/1.1"),
        "IPv6 literal must be bracketed in CONNECT request-target: {}",
        request_line
    );
    let host_bracketed = request.contains("Host: [::1]:443");
    assert!(host_bracketed, "Host header must also bracket IPv6");
}
/// IPv4 addresses appear in the CONNECT target without brackets.
#[test]
fn test_build_connect_request_ipv4_no_brackets() {
    let request = DiscordChannel::build_connect_request("127.0.0.1", 80, None);
    let expected_prefix = "CONNECT 127.0.0.1:80 HTTP/1.1";
    assert!(request.starts_with(expected_prefix));
}
/// Percent-encoded userinfo in the proxy URL is decoded before Basic-auth encoding.
#[test]
fn test_parse_proxy_url_percent_encoded_credentials() {
    use base64::Engine as _;
    let (_, _, auth) =
        DiscordChannel::parse_proxy_url("http://user%40corp:p%40ss@127.0.0.1:3128")
            .expect("should parse");
    let auth_line = auth.expect("auth header expected");
    // `strip_prefix` fails loudly if the scheme prefix is missing; `trim_start_matches`
    // would silently leave the line intact and surface a confusing base64 error instead.
    let b64 = auth_line
        .strip_prefix("Proxy-Authorization: Basic ")
        .expect("auth line must start with the Basic scheme prefix")
        // Tolerate an optional trailing CRLF on the header line.
        .trim_end_matches("\r\n");
    let decoded = base64::engine::general_purpose::STANDARD
        .decode(b64)
        .expect("valid base64");
    let creds = String::from_utf8(decoded).expect("valid utf8");
    assert_eq!(creds, "user@corp:p@ss");
}
/// An https:// value in HTTPS_PROXY is skipped (CONNECT needs an http:// proxy), and
/// the http:// value from HTTP_PROXY is used instead.
#[test]
fn test_gateway_proxy_env_candidate_fallback() {
    /// Snapshot of environment variables that is written back on drop, so the
    /// process environment is restored even if the code under test panics.
    struct EnvRestore(Vec<(&'static str, Option<std::ffi::OsString>)>);

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            for (key, saved) in &self.0 {
                // SAFETY: `set_var`/`remove_var` are unsafe because other threads may read
                // the environment concurrently. NOTE(review): this assumes no other test
                // reads or writes these variables while this test is running.
                unsafe {
                    match saved {
                        Some(value) => std::env::set_var(key, value),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    let https_key = "HTTPS_PROXY";
    let http_key = "HTTP_PROXY";
    // `var_os` keeps non-UTF-8 values so they are restored verbatim rather than removed.
    let restore = EnvRestore(
        [https_key, http_key]
            .into_iter()
            .map(|key| (key, std::env::var_os(key)))
            .collect(),
    );
    // SAFETY: same single-writer assumption as in `EnvRestore::drop`.
    unsafe {
        std::env::set_var(https_key, "https://proxy.example.com:3128");
        std::env::set_var(http_key, "http://proxy.example.com:3128");
    }
    let result =
        DiscordChannel::gateway_proxy_from_env("wss://gateway.discord.gg/?v=10&encoding=json");
    // Restore explicitly before asserting; the guard also restores on an early panic.
    drop(restore);
    assert_eq!(
        result.as_deref(),
        Some("http://proxy.example.com:3128"),
        "https:// HTTPS_PROXY must be skipped; http:// HTTP_PROXY must be returned"
    );
}
/// IDENTIFY (op 2) carries the token, the configured intents, and the client name.
#[test]
fn test_identify_payload_structure() {
    let payload: Value =
        serde_json::from_str(&DiscordChannel::build_identify_payload("my-token"))
            .expect("valid JSON");
    let data = &payload["d"];
    assert_eq!(payload["op"], 2);
    assert_eq!(data["token"], "my-token");
    assert_eq!(data["intents"], GATEWAY_INTENTS);
    assert_eq!(data["properties"]["browser"], "zeptoclaw");
}
/// A heartbeat (op 1) echoes the last seen sequence number in `d`.
#[test]
fn test_heartbeat_payload_with_sequence() {
    let encoded = DiscordChannel::build_heartbeat_payload(Some(42));
    let decoded: Value = serde_json::from_str(&encoded).expect("valid JSON");
    assert_eq!(decoded["d"], 42);
    assert_eq!(decoded["op"], 1);
}
/// Without a sequence number, the heartbeat's `d` is null.
#[test]
fn test_heartbeat_payload_without_sequence() {
    let encoded = DiscordChannel::build_heartbeat_payload(None);
    let decoded: Value = serde_json::from_str(&encoded).expect("valid JSON");
    let data_is_null = decoded["d"].is_null();
    assert_eq!(decoded["op"], 1);
    assert!(data_is_null);
}
/// A DISPATCH frame (op 0) deserializes with its sequence, event name, and data.
#[test]
fn test_gateway_payload_deserialization_dispatch() {
    let raw = r#"{
"op": 0,
"s": 5,
"t": "MESSAGE_CREATE",
"d": {
"id": "msg-100",
"content": "test message",
"channel_id": "ch-999",
"author": { "id": "u-1", "bot": false }
}
}"#;
    let GatewayPayload { op, d, s, t } =
        serde_json::from_str::<GatewayPayload>(raw).expect("should deserialize");
    assert_eq!(op, 0);
    assert_eq!(s, Some(5));
    assert_eq!(t.as_deref(), Some("MESSAGE_CREATE"));
    assert!(d.is_some());
}
/// A HELLO frame (op 10) has no sequence/event name, and its data yields the heartbeat interval.
#[test]
fn test_gateway_payload_deserialization_hello() {
    let raw = r#"{
"op": 10,
"d": { "heartbeat_interval": 41250 }
}"#;
    let GatewayPayload { op, d, s, t } =
        serde_json::from_str::<GatewayPayload>(raw).expect("should deserialize");
    assert_eq!(op, 10);
    assert!(s.is_none());
    assert!(t.is_none());
    let hello = d.as_ref().unwrap();
    let interval = DiscordChannel::extract_heartbeat_interval(hello).expect("should extract");
    assert_eq!(interval, 41250);
}
/// Only `token` is required; `enabled` defaults to false and `allow_from` to empty.
#[test]
fn test_discord_config_deserialize_defaults() {
    let json = r#"{ "token": "bot-abc-123" }"#;
    let config: DiscordConfig = serde_json::from_str(json).expect("should parse");
    // An omitted `enabled` must deserialize to false (the channel is opt-in).
    assert!(!config.enabled);
    assert_eq!(config.token, "bot-abc-123");
    assert!(config.allow_from.is_empty());
}
/// Every explicitly provided config field is deserialized as given.
#[test]
fn test_discord_config_deserialize_full() {
    let json = r#"{
"enabled": true,
"token": "tok-full",
"allow_from": ["111", "222", "333"]
}"#;
    let DiscordConfig {
        enabled,
        token,
        allow_from,
        ..
    } = serde_json::from_str::<DiscordConfig>(json).expect("should parse");
    assert!(enabled);
    assert_eq!(token, "tok-full");
    assert_eq!(allow_from, ["111", "222", "333"]);
}
/// `DiscordConfig::default()` is disabled, tokenless, and has an empty allow-list.
#[test]
fn test_discord_config_default_trait() {
    let DiscordConfig {
        enabled,
        token,
        allow_from,
        ..
    } = DiscordConfig::default();
    assert!(!enabled);
    assert!(token.is_empty());
    assert!(allow_from.is_empty());
}
/// A frame with only `op` (here a heartbeat ACK, op 11) leaves the optional fields empty.
#[test]
fn test_gateway_payload_minimal_fields() {
    let raw = r#"{ "op": 11 }"#;
    let GatewayPayload { op, d, s, t } =
        serde_json::from_str::<GatewayPayload>(raw).expect("should parse");
    assert_eq!(op, 11);
    assert!(d.is_none());
    assert!(s.is_none());
    assert!(t.is_none());
}
/// A RECONNECT frame (op 7) with `d: null` deserializes.
#[test]
fn test_gateway_payload_reconnect_opcode() {
    let raw = r#"{ "op": 7, "d": null }"#;
    let GatewayPayload { op, .. } =
        serde_json::from_str::<GatewayPayload>(raw).expect("should parse");
    assert_eq!(op, 7);
}
/// An INVALID_SESSION frame (op 9) keeps its boolean `d` value.
#[test]
fn test_gateway_payload_invalid_session_opcode() {
    let raw = r#"{ "op": 9, "d": false }"#;
    let GatewayPayload { op, d, .. } =
        serde_json::from_str::<GatewayPayload>(raw).expect("should parse");
    assert_eq!(op, 9);
    assert_eq!(d, Some(json!(false)));
}
/// A whitespace-only author id causes the message to be dropped.
#[test]
fn test_message_create_empty_author_id() {
    let event = json!({
        "id": "msg-edge-1",
        "content": "valid content",
        "channel_id": "ch-100",
        "author": { "id": " ", "bot": false }
    });
    let parsed = DiscordChannel::parse_message_create(&event, &[], false);
    assert!(parsed.is_none());
}
/// An event with no `content` key produces no inbound message.
#[test]
fn test_message_create_missing_content_field() {
    let event = json!({
        "id": "msg-edge-2",
        "channel_id": "ch-200",
        "author": { "id": "user-42", "bot": false }
    });
    let parsed = DiscordChannel::parse_message_create(&event, &[], false);
    assert!(parsed.is_none());
}
/// A whitespace-only channel id causes the message to be dropped.
#[test]
fn test_message_create_empty_channel_id() {
    let event = json!({
        "id": "msg-edge-3",
        "content": "hello",
        "channel_id": " ",
        "author": { "id": "user-42", "bot": false }
    });
    let parsed = DiscordChannel::parse_message_create(&event, &[], false);
    assert!(parsed.is_none());
}
/// Surrounding whitespace is trimmed from inbound message content.
#[test]
fn test_message_create_content_trimmed() {
    let event = json!({
        "id": "msg-trim",
        "content": " padded message ",
        "channel_id": "ch-100",
        "author": { "id": "user-1" }
    });
    let content = DiscordChannel::parse_message_create(&event, &[], false)
        .unwrap()
        .content;
    assert_eq!(content, "padded message");
}
/// `HelloData` picks up `heartbeat_interval` from the HELLO body.
#[test]
fn test_hello_data_deserialization() {
    let body = json!({
        "heartbeat_interval": 45000
    });
    let HelloData { heartbeat_interval } =
        serde_json::from_value::<HelloData>(body).expect("should parse");
    assert_eq!(heartbeat_interval, 45000);
}
/// Unknown keys such as `_trace` are ignored when deserializing `HelloData`.
#[test]
fn test_hello_data_extra_fields_ignored() {
    let body = json!({
        "heartbeat_interval": 30000,
        "_trace": ["gateway-1"]
    });
    let HelloData { heartbeat_interval } =
        serde_json::from_value::<HelloData>(body).expect("should parse with extra fields");
    assert_eq!(heartbeat_interval, 30000);
}
/// Content exactly at the length limit is sent unchanged, with no ellipsis.
#[test]
fn test_outbound_message_exactly_at_limit() {
    let at_limit = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH);
    let outbound = OutboundMessage::new("discord", "ch-100", &at_limit);
    let body = DiscordChannel::build_send_payload(&outbound).expect("should build");
    let sent = body["content"].as_str().unwrap();
    let truncated = sent.ends_with("...");
    assert_eq!(sent.len(), DISCORD_MAX_MESSAGE_LENGTH);
    assert!(!truncated);
}
/// Content one character over the limit is truncated and marked with an ellipsis.
#[test]
fn test_outbound_message_one_over_limit() {
    let just_over = "b".repeat(DISCORD_MAX_MESSAGE_LENGTH + 1);
    let outbound = OutboundMessage::new("discord", "ch-100", &just_over);
    let body = DiscordChannel::build_send_payload(&outbound).expect("should build");
    let sent = body["content"].as_str().unwrap();
    assert!(sent.len() <= DISCORD_MAX_MESSAGE_LENGTH);
    assert!(sent.ends_with("..."));
}
/// The intents mask is the sum of bits 0, 9, 12 and 15 (37377).
#[test]
fn test_gateway_intents_bitmask() {
    let summed: u64 = [1u64, 512, 4096, 32768].iter().sum();
    assert_eq!(GATEWAY_INTENTS, summed);
    assert_eq!(GATEWAY_INTENTS, 37377);
}
/// IDENTIFY reports the host OS as given by `std::env::consts::OS`.
#[test]
fn test_identify_payload_has_os() {
    let encoded = DiscordChannel::build_identify_payload("tok");
    let decoded: Value = serde_json::from_str(&encoded).expect("valid JSON");
    let properties = &decoded["d"]["properties"];
    let reported_os = properties["os"]
        .as_str()
        .expect("os field should be a string");
    assert_eq!(reported_os, std::env::consts::OS);
    assert!(!reported_os.is_empty());
}
}