sapphire-agent 0.4.0

A personal AI assistant agent with Matrix/Discord channels, Anthropic backend, and a sapphire-workspace memory layer
use crate::channel::{Attachment, Channel, IncomingMessage, MAX_ATTACHMENT_BYTES, OutgoingMessage};
use crate::config::MatrixConfig;
use anyhow::{Context, Result};
use async_trait::async_trait;
use matrix_sdk::{
    Client, SessionMeta, SessionTokens,
    authentication::matrix::MatrixSession,
    config::SyncSettings,
    media::{MediaFormat, MediaRequestParameters},
    ruma::{
        OwnedEventId, OwnedRoomId, OwnedUserId,
        events::relation::Thread,
        events::room::message::{
            ImageMessageEventContent, MessageType, OriginalSyncRoomMessageEvent, Relation,
            RoomMessageEventContent,
        },
    },
};
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{OnceCell, mpsc};
use tracing::{debug, info, warn};

/// Matrix implementation of the [`Channel`] trait.
///
/// Holds the connection settings copied from [`MatrixConfig`] together with a
/// lazily-initialised SDK [`Client`]; the client is only built the first time
/// one of the channel operations needs it.
pub struct MatrixChannel {
    /// Homeserver URL handed to the client builder.
    homeserver: String,
    /// Access token used to restore the existing Matrix session.
    access_token: String,
    /// Fully-qualified Matrix user id of this account; parsed when the client
    /// is built and used to ignore the account's own messages while listening.
    user_id: String,
    /// Device id of the session that is restored.
    device_id: String,
    /// Room ids whose messages are forwarded; events from any other room are
    /// dropped by the listener.
    room_ids: HashSet<String>,
    /// Sender allow-list. When empty, every sender (except this account) is
    /// accepted; otherwise only the listed user ids are.
    allowed_users: HashSet<String>,
    /// Optional E2EE recovery key applied once after the session is restored.
    recovery_key: Option<String>,
    /// Directory created on demand and used for the client's SQLite store.
    state_dir: PathBuf,
    /// Lazily-built SDK client, initialised at most once.
    client: Arc<OnceCell<Client>>,
}

impl std::fmt::Debug for MatrixChannel {
    /// Formats the channel showing only `homeserver`, `user_id` and
    /// `room_ids`. Every other field (including the access token and the
    /// recovery key) is elided and represented by the trailing `..` that
    /// `finish_non_exhaustive` emits.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            homeserver,
            user_id,
            room_ids,
            ..
        } = self;

        let mut out = f.debug_struct("MatrixChannel");
        out.field("homeserver", homeserver);
        out.field("user_id", user_id);
        out.field("room_ids", room_ids);
        out.finish_non_exhaustive()
    }
}

impl MatrixChannel {
    /// Creates a channel from `cfg`.
    ///
    /// Only copies configuration values; the Matrix client itself is built
    /// lazily on first use.
    pub fn new(cfg: &MatrixConfig) -> Self {
        let room_ids = cfg.room_ids.iter().cloned().collect();
        let allowed_users = cfg.allowed_users.iter().cloned().collect();

        Self {
            homeserver: cfg.homeserver.clone(),
            access_token: cfg.access_token.clone(),
            user_id: cfg.user_id.clone(),
            device_id: cfg.device_id.clone(),
            room_ids,
            allowed_users,
            recovery_key: cfg.recovery_key.clone(),
            state_dir: cfg.resolved_state_dir(),
            client: Arc::new(OnceCell::new()),
        }
    }

    /// Returns the shared client, building it on the first call.
    ///
    /// # Errors
    /// Propagates any error from [`Self::build_client`]; in that case the
    /// cell stays empty and a later call tries again.
    async fn get_or_init_client(&self) -> Result<&Client> {
        self.client.get_or_try_init(|| self.build_client()).await
    }

    /// Builds a client backed by an SQLite store in `state_dir`, restores the
    /// configured session and, if a recovery key is configured, attempts E2EE
    /// recovery.
    ///
    /// # Errors
    /// Fails if the state directory cannot be created, the client cannot be
    /// built, `user_id` does not parse, or the session cannot be restored.
    /// A failed E2EE recovery is only logged.
    async fn build_client(&self) -> Result<Client> {
        let state_dir = self.state_dir.as_path();
        std::fs::create_dir_all(state_dir)
            .with_context(|| format!("Failed to create state dir: {}", state_dir.display()))?;

        let builder = Client::builder()
            .homeserver_url(&self.homeserver)
            .sqlite_store(state_dir, None);
        let client = builder
            .build()
            .await
            .context("Failed to build Matrix client")?;

        let meta = SessionMeta {
            user_id: self
                .user_id
                .parse::<OwnedUserId>()
                .context("Invalid user_id in config")?,
            device_id: self.device_id.as_str().into(),
        };
        let tokens = SessionTokens {
            access_token: self.access_token.clone(),
            refresh_token: None,
        };
        client
            .restore_session(MatrixSession { meta, tokens })
            .await
            .context("Failed to restore Matrix session")?;

        // Recovery is best-effort: a failure is logged and start-up continues.
        if let Some(key) = &self.recovery_key {
            match client.encryption().recovery().recover(key).await {
                Ok(_) => info!("E2EE recovery key applied successfully"),
                Err(e) => warn!("E2EE recovery failed (key may already be active): {e}"),
            }
        }

        info!(
            "Matrix client initialized for {} on {}",
            self.user_id, self.homeserver
        );
        Ok(client)
    }

    /// Looks up the room with id `room_id_str` on the (lazily built) client.
    ///
    /// # Errors
    /// Fails if the client cannot be initialised, if `room_id_str` is not a
    /// valid room id, or if the client does not know the room.
    async fn get_room(&self, room_id_str: &str) -> Result<matrix_sdk::Room> {
        let client = self.get_or_init_client().await?;
        let parsed = room_id_str
            .parse::<OwnedRoomId>()
            .context("Invalid room_id")?;
        let room = client.get_room(&parsed);
        room.context("Room not found")
    }
}

/// Download the bytes for a Matrix `m.image` event via the SDK's media API.
///
/// Returns `None` (with a warning log) when:
/// - the declared MIME type does not start with `image/`,
/// - the declared or the actually downloaded size exceeds
///   `MAX_ATTACHMENT_BYTES` (the 5 MB budget), or
/// - the download itself fails.
///
/// The caller should drop the event in that case. When the event declares no
/// MIME type, `image/octet-stream` is assumed.
async fn download_matrix_image(
    client: &Client,
    image: &ImageMessageEventContent,
) -> Option<Attachment> {
    let mime = image
        .info
        .as_ref()
        .and_then(|info| info.mimetype.clone())
        .unwrap_or_else(|| "image/octet-stream".to_string());

    if !mime.starts_with("image/") {
        warn!(
            "Matrix image '{}' has non-image MIME type '{}'; skipping",
            image.body, mime
        );
        return None;
    }

    // Cheap pre-check on the sender-declared size so obviously oversized
    // files are never downloaded.
    if let Some(size) = image.info.as_ref().and_then(|info| info.size) {
        let size: u64 = size.into();
        // Compare in u64: widening the limit is lossless, whereas narrowing
        // `size` to usize could truncate on 32-bit targets and let an
        // oversized file slip past this check.
        if size > MAX_ATTACHMENT_BYTES as u64 {
            warn!(
                "Matrix image '{}' is {} bytes (>5MB); skipping",
                image.body, size
            );
            return None;
        }
    }

    let request = MediaRequestParameters {
        source: image.source.clone(),
        format: MediaFormat::File,
    };

    // The declared size is not trusted: the downloaded payload is checked
    // against the budget again.
    match client.media().get_media_content(&request, true).await {
        Ok(bytes) if bytes.len() <= MAX_ATTACHMENT_BYTES => Some(Attachment {
            media_type: mime,
            data: bytes,
        }),
        Ok(bytes) => {
            warn!(
                "Matrix image '{}' decoded to {} bytes (>5MB); skipping",
                image.body,
                bytes.len()
            );
            None
        }
        Err(e) => {
            warn!("Failed to download Matrix image '{}': {e}", image.body);
            None
        }
    }
}

#[async_trait]
impl Channel for MatrixChannel {
    /// Static identifier of this channel implementation.
    fn name(&self) -> &str {
        "matrix"
    }

    /// Sends `message.content` (rendered as Markdown) to `message.room_id`.
    ///
    /// When `message.thread_id` is set, the message is attached to that
    /// thread; the thread root is passed for both arguments of
    /// `Thread::plain`.
    ///
    /// # Errors
    /// Fails if the room cannot be resolved, `thread_id` is not a valid event
    /// id, or the send request fails.
    async fn send(&self, message: &OutgoingMessage) -> Result<()> {
        let room = self.get_room(&message.room_id).await?;

        let mut content = RoomMessageEventContent::text_markdown(&message.content);

        if let Some(thread_id) = &message.thread_id {
            let thread_root: OwnedEventId = thread_id.parse().context("Invalid thread_id")?;
            content.relates_to = Some(Relation::Thread(Thread::plain(
                thread_root.clone(),
                thread_root,
            )));
        }

        room.send(content).await.context("Failed to send message")?;
        Ok(())
    }

    /// Registers a room-message handler and then runs the sync loop forever,
    /// reconnecting with exponential back-off whenever the sync call returns.
    ///
    /// Text and image messages from allowed rooms and senders are forwarded
    /// on `tx`; all other events are dropped. This function only returns
    /// (with an error) if the client cannot be initialised.
    async fn listen(&self, tx: mpsc::Sender<IncomingMessage>) -> Result<()> {
        let client = self.get_or_init_client().await?;
        let allowed_rooms = self.room_ids.clone();
        let allowed_users = self.allowed_users.clone();
        let bot_user_id = self.user_id.clone();

        client.add_event_handler(
            move |event: OriginalSyncRoomMessageEvent, room: matrix_sdk::Room| {
                // The handler is invoked once per event, so each invocation
                // needs its own copies to move into the returned future.
                let tx = tx.clone();
                let allowed_rooms = allowed_rooms.clone();
                let allowed_users = allowed_users.clone();
                let bot_user_id = bot_user_id.clone();

                async move {
                    let room_id_str = room.room_id().as_str().to_string();
                    if !allowed_rooms.contains(&room_id_str) {
                        return;
                    }
                    // Never react to our own messages.
                    if event.sender.as_str() == bot_user_id {
                        return;
                    }
                    // An empty allow-list means "accept every sender".
                    if !allowed_users.is_empty() && !allowed_users.contains(event.sender.as_str()) {
                        debug!("Ignoring message from non-allowed user: {}", event.sender);
                        return;
                    }

                    let (content, attachments) = match &event.content.msgtype {
                        MessageType::Text(text_content) => (text_content.body.clone(), Vec::new()),
                        MessageType::Image(image_content) => {
                            let attachment =
                                download_matrix_image(&room.client(), image_content).await;
                            // Text accompanying the image as reported by
                            // `caption()`; empty when the event carries none.
                            let caption = image_content
                                .caption()
                                .map(str::to_string)
                                .unwrap_or_default();
                            match attachment {
                                Some(att) => (caption, vec![att]),
                                // Image could not be fetched: drop the event.
                                None => return,
                            }
                        }
                        // Other message types are not supported.
                        _ => return,
                    };

                    let thread_id = event.content.relates_to.as_ref().and_then(|rel| {
                        if let Relation::Thread(thread) = rel {
                            Some(thread.event_id.to_string())
                        } else {
                            None
                        }
                    });

                    let msg = IncomingMessage {
                        id: event.event_id.to_string(),
                        sender: event.sender.to_string(),
                        content,
                        room_id: room_id_str,
                        // NOTE(review): always 0 — the event's server
                        // timestamp is not propagated; confirm whether any
                        // consumer relies on this field.
                        timestamp: 0,
                        thread_id,
                        attachments,
                    };

                    if let Err(e) = tx.send(msg).await {
                        warn!("Failed to forward Matrix message: {e}");
                    }
                }
            },
        );

        info!("Starting Matrix sync loop...");
        let sync_settings = SyncSettings::default().timeout(std::time::Duration::from_secs(30));
        let min_backoff = std::time::Duration::from_secs(1);
        let max_backoff = std::time::Duration::from_secs(300);
        let stable_threshold = std::time::Duration::from_secs(60);
        let mut backoff = min_backoff;
        loop {
            let started = std::time::Instant::now();
            let outcome = client.sync(sync_settings.clone()).await;

            // Measure only the time the sync itself ran. Reading the clock
            // after the sleep would count the back-off delay as "stable"
            // time, so any delay >= the threshold would reset the back-off
            // and `max_backoff` could never be reached.
            let ran_for = started.elapsed();
            if ran_for >= stable_threshold {
                // The connection was healthy for a while: start over with the
                // shortest delay instead of reusing a stale, longer one.
                backoff = min_backoff;
            }

            match outcome {
                Ok(()) => {
                    warn!("Matrix sync loop exited without error; reconnecting in {backoff:?}");
                }
                Err(e) => {
                    warn!("Matrix sync loop exited with error: {e}; reconnecting in {backoff:?}");
                }
            }
            tokio::time::sleep(backoff).await;
            // Escalate for the next consecutive failure, capped at the maximum.
            backoff = (backoff * 2).min(max_backoff);
            info!("Reconnecting Matrix sync loop...");
        }
    }

    /// Turns the typing indicator on in `room_id`.
    ///
    /// # Errors
    /// Fails if the room cannot be resolved or the typing notice request fails.
    async fn start_typing(&self, room_id: &str) -> Result<()> {
        let room = self.get_room(room_id).await?;
        room.typing_notice(true)
            .await
            .context("Failed to send typing notice")?;
        Ok(())
    }

    /// Turns the typing indicator off in `room_id`.
    ///
    /// # Errors
    /// Fails if the room cannot be resolved or the typing notice request fails.
    async fn stop_typing(&self, room_id: &str) -> Result<()> {
        let room = self.get_room(room_id).await?;
        room.typing_notice(false)
            .await
            .context("Failed to stop typing notice")?;
        Ok(())
    }
}