pub mod discord;
pub mod matrix;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tracing::{error, warn};
/// Upper bound on attachment size, in bytes (5 MiB).
// NOTE(review): nothing in this file enforces the limit; presumably the channel
// backends check it when collecting attachments — confirm.
pub const MAX_ATTACHMENT_BYTES: usize = 5 << 20;
/// A binary payload carried by an incoming message.
// NOTE(review): size is presumably capped at `MAX_ATTACHMENT_BYTES` by the producing
// backend; this type itself does not enforce any limit.
#[derive(Clone, Debug)]
pub struct Attachment {
    /// Media type of `data` (presumably a MIME string such as `image/png` — confirm).
    pub media_type: String,
    /// Raw attachment bytes.
    pub data: Vec<u8>,
}
/// A message received from a [`Channel`] backend.
// NOTE(review): `dead_code` is presumably allowed because some fields are not read
// anywhere yet — confirm before removing.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct IncomingMessage {
    /// Message identifier (format presumably backend-specific).
    pub id: String,
    /// Author of the message, as reported by the backend.
    pub sender: String,
    /// Message text.
    pub content: String,
    /// Room the message arrived in; `Channels::listen_all` records it for routing.
    pub room_id: String,
    /// Message timestamp.
    // NOTE(review): unit (seconds vs. milliseconds) and epoch are not visible here — confirm.
    pub timestamp: u64,
    /// Thread the message belongs to, if any.
    pub thread_id: Option<String>,
    /// Binary attachments carried by the message.
    pub attachments: Vec<Attachment>,
}
/// A message to be sent to a room through a [`Channel`].
#[derive(Clone, Debug)]
pub struct OutgoingMessage {
    /// Message text.
    pub content: String,
    /// Target room identifier.
    pub room_id: String,
    /// Thread to reply in, if any.
    pub thread_id: Option<String>,
}

impl OutgoingMessage {
    /// Builds a top-level (non-threaded) message with `content` addressed to `room_id`.
    pub fn new(content: impl Into<String>, room_id: impl Into<String>) -> Self {
        let content: String = content.into();
        let room_id: String = room_id.into();
        Self {
            content,
            room_id,
            thread_id: None,
        }
    }
}
/// Metadata about a room, as returned by [`Channel::room_info`].
#[derive(Clone, Debug)]
pub struct RoomInfo {
    /// Display name of the room.
    pub name: String,
    /// Optional room description/topic.
    pub description: Option<String>,
    /// Free-form room kind string (values presumably backend-defined — confirm).
    pub kind: String,
}
/// A chat backend that can receive and send messages.
///
/// Implementors must provide [`name`](Channel::name), [`listen`](Channel::listen) and
/// [`send`](Channel::send). Room metadata and typing indicators have no-op defaults.
// NOTE(review): `dead_code` is presumably allowed because not every trait method has a
// caller yet — confirm.
#[async_trait]
#[allow(dead_code)]
pub trait Channel: Send + Sync {
    /// Identifier of this backend.
    fn name(&self) -> &str;

    /// Pushes incoming messages into `tx`.
    ///
    /// `Channels::listen_all` runs this in its own task and logs any error it returns.
    async fn listen(&self, tx: tokio::sync::mpsc::Sender<IncomingMessage>) -> anyhow::Result<()>;

    /// Sends `message`, presumably to the room identified by `message.room_id`.
    async fn send(&self, message: &OutgoingMessage) -> anyhow::Result<()>;

    /// Returns metadata about a room; the default knows nothing and returns `None`.
    async fn room_info(&self, _room_id: &str) -> Option<RoomInfo> {
        None
    }

    /// Shows a typing indicator in a room; the default does nothing.
    async fn start_typing(&self, _room_id: &str) -> anyhow::Result<()> {
        Ok(())
    }

    /// Clears the typing indicator in a room; the default does nothing.
    async fn stop_typing(&self, _room_id: &str) -> anyhow::Result<()> {
        Ok(())
    }
}
pub struct Channels {
list: Vec<(String, Arc<dyn Channel>)>,
routing: RwLock<HashMap<String, String>>,
}
impl Channels {
pub fn new(list: Vec<(String, Arc<dyn Channel>)>, seed: HashMap<String, String>) -> Self {
Self {
list,
routing: RwLock::new(seed),
}
}
pub fn names(&self) -> Vec<&str> {
self.list.iter().map(|(n, _)| n.as_str()).collect()
}
#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
self.list.is_empty()
}
pub async fn channel_name_for_room(&self, room_id: &str) -> Option<String> {
self.routing.read().await.get(room_id).cloned()
}
fn channel_by_name(&self, name: &str) -> Option<&Arc<dyn Channel>> {
self.list.iter().find(|(n, _)| n == name).map(|(_, c)| c)
}
async fn channel_for_room_or_first(&self, room_id: &str) -> Option<Arc<dyn Channel>> {
if let Some(name) = self.channel_name_for_room(room_id).await
&& let Some(ch) = self.channel_by_name(&name)
{
return Some(Arc::clone(ch));
}
if self.list.len() == 1 {
return Some(Arc::clone(&self.list[0].1));
}
None
}
pub async fn send(&self, msg: &OutgoingMessage) -> Result<()> {
let ch = self
.channel_for_room_or_first(&msg.room_id)
.await
.ok_or_else(|| anyhow!("no channel registered for room {}", msg.room_id))?;
ch.send(msg).await
}
pub async fn start_typing(&self, room_id: &str) -> Result<()> {
if let Some(ch) = self.channel_for_room_or_first(room_id).await {
ch.start_typing(room_id).await?;
}
Ok(())
}
pub async fn room_info(&self, room_id: &str) -> Option<RoomInfo> {
let ch = self.channel_for_room_or_first(room_id).await?;
ch.room_info(room_id).await
}
pub async fn stop_typing(&self, room_id: &str) -> Result<()> {
if let Some(ch) = self.channel_for_room_or_first(room_id).await {
ch.stop_typing(room_id).await?;
}
Ok(())
}
pub async fn listen_all(self: Arc<Self>, tx: mpsc::Sender<IncomingMessage>) -> Result<()> {
if self.list.is_empty() {
return Err(anyhow!("listen_all called with no channels registered"));
}
let mut handles: Vec<tokio::task::JoinHandle<()>> = Vec::new();
for entry in &self.list {
let (name, ch) = (entry.0.clone(), Arc::clone(&entry.1));
let outer_tx = tx.clone();
let me = Arc::clone(&self);
let (inner_tx, mut inner_rx) = mpsc::channel::<IncomingMessage>(64);
let listen_name = name.clone();
handles.push(tokio::spawn(async move {
if let Err(e) = ch.listen(inner_tx).await {
error!("Channel '{listen_name}' listen error: {e:#}");
}
}));
let forward_name = name.clone();
handles.push(tokio::spawn(async move {
while let Some(msg) = inner_rx.recv().await {
me.routing
.write()
.await
.insert(msg.room_id.clone(), forward_name.clone());
if outer_tx.send(msg).await.is_err() {
warn!("Channel '{forward_name}' forwarder: receiver closed");
break;
}
}
}));
}
for h in handles {
let _ = h.await;
}
Ok(())
}
}
/// Builds the initial room → channel-name routing table from configuration.
///
/// Every configured Matrix room id maps to `"matrix"`, and every configured Discord
/// channel id maps to `"discord"`. Discord entries are applied after Matrix ones, so an
/// id listed under both ends up routed to `"discord"`.
pub fn seed_routing_from_config(config: &crate::config::Config) -> HashMap<String, String> {
    let matrix_entries = config
        .matrix
        .iter()
        .flat_map(|m| &m.room_ids)
        .map(|room| (room.clone(), "matrix".to_string()));
    let discord_entries = config
        .discord
        .iter()
        .flat_map(|d| &d.channel_ids)
        .map(|chan| (chan.clone(), "discord".to_string()));
    // Collecting into a HashMap inserts in iteration order, so later (Discord) keys win.
    matrix_entries.chain(discord_entries).collect()
}