systemprompt-agent 0.2.0

Core Agent protocol module for systemprompt.io
Documentation
use async_trait::async_trait;
use systemprompt_database::DbPool;
use systemprompt_identifiers::{ContextId, SessionId, UserId};
use systemprompt_traits::{ContextProvider, ContextProviderError, ContextWithStats};

use crate::repository::ContextRepository;

#[derive(Debug, Clone)]
pub struct ContextProviderService {
    repo: ContextRepository,
}

impl ContextProviderService {
    pub fn new(db_pool: &DbPool) -> anyhow::Result<Self> {
        Ok(Self {
            repo: ContextRepository::new(db_pool)?,
        })
    }
}

#[async_trait]
impl ContextProvider for ContextProviderService {
    async fn list_contexts_with_stats(
        &self,
        user_id: &UserId,
    ) -> Result<Vec<ContextWithStats>, ContextProviderError> {
        let contexts = self
            .repo
            .list_contexts_with_stats(user_id)
            .await
            .map_err(|e| ContextProviderError::Database(e.to_string()))?;

        Ok(contexts
            .into_iter()
            .map(|c| ContextWithStats {
                context_id: c.context_id,
                user_id: c.user_id,
                name: c.name,
                created_at: c.created_at,
                updated_at: c.updated_at,
                task_count: c.task_count,
                message_count: c.message_count,
                last_message_at: c.last_message_at,
            })
            .collect())
    }

    async fn get_context(
        &self,
        context_id: &ContextId,
        user_id: &UserId,
    ) -> Result<ContextWithStats, ContextProviderError> {
        let context = self
            .repo
            .get_context(context_id, user_id)
            .await
            .map_err(|e| match e {
                systemprompt_traits::RepositoryError::NotFound(msg) => {
                    ContextProviderError::NotFound(msg)
                },
                other => ContextProviderError::Database(other.to_string()),
            })?;

        let all_contexts = self
            .repo
            .list_contexts_with_stats(user_id)
            .await
            .map_err(|e| ContextProviderError::Database(e.to_string()))?;

        let context_with_stats = all_contexts
            .into_iter()
            .find(|c| c.context_id == context.context_id)
            .ok_or_else(|| {
                ContextProviderError::NotFound(format!("Context {} not found", context_id))
            })?;

        Ok(ContextWithStats {
            context_id: context_with_stats.context_id,
            user_id: context_with_stats.user_id,
            name: context_with_stats.name,
            created_at: context_with_stats.created_at,
            updated_at: context_with_stats.updated_at,
            task_count: context_with_stats.task_count,
            message_count: context_with_stats.message_count,
            last_message_at: context_with_stats.last_message_at,
        })
    }

    async fn create_context(
        &self,
        user_id: &UserId,
        session_id: Option<&SessionId>,
        name: &str,
    ) -> Result<ContextId, ContextProviderError> {
        self.repo
            .create_context(user_id, session_id, name)
            .await
            .map_err(|e| ContextProviderError::Database(e.to_string()))
    }

    async fn update_context_name(
        &self,
        context_id: &ContextId,
        user_id: &UserId,
        name: &str,
    ) -> Result<(), ContextProviderError> {
        self.repo
            .update_context_name(context_id, user_id, name)
            .await
            .map_err(|e| match e {
                systemprompt_traits::RepositoryError::NotFound(msg) => {
                    ContextProviderError::NotFound(msg)
                },
                other => ContextProviderError::Database(other.to_string()),
            })
    }

    async fn delete_context(
        &self,
        context_id: &ContextId,
        user_id: &UserId,
    ) -> Result<(), ContextProviderError> {
        self.repo
            .delete_context(context_id, user_id)
            .await
            .map_err(|e| match e {
                systemprompt_traits::RepositoryError::NotFound(msg) => {
                    ContextProviderError::NotFound(msg)
                },
                other => ContextProviderError::Database(other.to_string()),
            })
    }
}