systemprompt-extension 0.2.0

Extension framework for systemprompt.io - register custom modules, providers, and APIs
Documentation
use std::any::TypeId;
use std::collections::HashMap;

use crate::any::AnyExtension;
use crate::error::LoaderError;
pub use crate::registry::RESERVED_PATHS;
#[cfg(feature = "web")]
use crate::typed::ApiExtensionTypedDyn;
use crate::typed::SchemaExtensionTyped;
use crate::types::ExtensionType;

/// Registry of loaded extensions, with lookup by string id and by concrete type.
///
/// Extensions are owned by `extensions` in the order they were added; the two lookup maps
/// store indices into that vector rather than owning anything themselves.
pub struct TypedExtensionRegistry {
    /// Every registered extension, in registration order.
    extensions: Vec<Box<dyn AnyExtension>>,
    /// Extension id -> index into `extensions`.
    by_id: HashMap<String, usize>,
    /// `TypeId` of an extension's concrete type -> index into `extensions`.
    by_type: HashMap<TypeId, usize>,
    /// Base paths recorded for API extensions.
    api_paths: Vec<String>,
}

impl Default for TypedExtensionRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl TypedExtensionRegistry {
    /// Creates an empty registry with no extensions and no recorded API paths.
    #[must_use]
    pub fn new() -> Self {
        Self {
            extensions: Vec::new(),
            by_id: HashMap::new(),
            by_type: HashMap::new(),
            api_paths: Vec::new(),
        }
    }

    /// Appends an already-boxed extension and indexes it.
    ///
    /// The extension becomes reachable by its id (via [`Self::get`] / [`Self::has`]) and by the
    /// concrete type behind `as_any()` (via [`Self::get_typed`] / [`Self::has_type`]). With the
    /// `web` feature enabled, an API extension's base path is also appended to
    /// [`Self::api_paths`].
    ///
    /// This method performs no validation itself. If the id or type is already present, the
    /// lookup is re-pointed at the new extension, while the earlier one stays in the iteration
    /// order of [`Self::all_extensions`].
    pub(crate) fn add_boxed(&mut self, ext: Box<dyn AnyExtension>) {
        let idx = self.extensions.len();
        self.by_id.insert(ext.id().to_string(), idx);
        // Key by the dynamic type behind `as_any()`: that is exactly what `get_typed`
        // downcasts to, so typed lookups can find extensions added through this path.
        // Without this insert `by_type` stayed empty and `get_typed`/`has_type` never matched.
        self.by_type.insert(ext.as_any().type_id(), idx);

        #[cfg(feature = "web")]
        if let Some(api) = ext.as_api() {
            self.api_paths.push(api.base_path().to_string());
        }

        self.extensions.push(ext);
    }

    /// Checks whether `path` may be used as the API base path of `extension_id`.
    ///
    /// # Errors
    ///
    /// - [`LoaderError::InvalidBasePath`] if `path` starts with neither `/api/` nor `/.`.
    /// - [`LoaderError::ReservedPathCollision`] if `path` starts with any entry of
    ///   [`RESERVED_PATHS`], or if it overlaps an already-recorded API path (the two are equal,
    ///   or one lies beneath the other at a `/` segment boundary).
    pub fn validate_api_path(&self, extension_id: &str, path: &str) -> Result<(), LoaderError> {
        if !path.starts_with("/api/") && !path.starts_with("/.") {
            return Err(LoaderError::InvalidBasePath {
                extension: extension_id.to_string(),
                path: path.to_string(),
            });
        }

        // Reserved prefixes stay a plain string-prefix match, so anything sharing a reserved
        // prefix is rejected outright.
        for reserved in RESERVED_PATHS {
            if path.starts_with(reserved) {
                return Err(LoaderError::ReservedPathCollision {
                    extension: extension_id.to_string(),
                    path: path.to_string(),
                });
            }
        }

        // Between extensions, only real route-tree overlap is a conflict: `/api/blog` and
        // `/api/blogs` are distinct, whereas `/api/blog` and `/api/blog/admin` collide.
        for existing in &self.api_paths {
            if Self::nests_under(path, existing) || Self::nests_under(existing, path) {
                return Err(LoaderError::ReservedPathCollision {
                    extension: extension_id.to_string(),
                    path: format!("{} (conflicts with {})", path, existing),
                });
            }
        }

        Ok(())
    }

    /// Returns `true` if `path` equals `prefix` or lies beneath it at a `/` segment boundary.
    fn nests_under(path: &str, prefix: &str) -> bool {
        matches!(
            path.strip_prefix(prefix),
            Some(rest) if rest.is_empty() || rest.starts_with('/') || prefix.ends_with('/')
        )
    }

    /// Returns `true` if an extension whose concrete type is `E` is indexed.
    #[must_use]
    pub fn has_type<E: ExtensionType>(&self) -> bool {
        self.by_type.contains_key(&TypeId::of::<E>())
    }

    /// Returns `true` if an extension with the given id is registered.
    #[must_use]
    pub fn has(&self, id: &str) -> bool {
        self.by_id.contains_key(id)
    }

    /// Looks up an extension by id.
    #[must_use]
    pub fn get(&self, id: &str) -> Option<&dyn AnyExtension> {
        self.by_id.get(id).map(|&idx| self.extensions[idx].as_ref())
    }

    /// Looks up the extension of concrete type `E` and downcasts it.
    ///
    /// Returns `None` if no extension of that type is indexed, or if the downcast fails.
    #[must_use]
    pub fn get_typed<E: ExtensionType + 'static>(&self) -> Option<&E> {
        self.by_type
            .get(&TypeId::of::<E>())
            .and_then(|&idx| self.extensions[idx].as_any().downcast_ref())
    }

    /// Returns every extension that exposes a schema, in ascending `migration_weight()` order.
    ///
    /// The sort is stable, so extensions with equal weight keep their registration order.
    pub fn schema_extensions(&self) -> impl Iterator<Item = &dyn SchemaExtensionTyped> {
        let mut schemas: Vec<_> = self
            .extensions
            .iter()
            .filter_map(|e| e.as_schema())
            .collect();
        schemas.sort_by_key(|s| s.migration_weight());
        schemas.into_iter()
    }

    /// Returns every extension that exposes an API, in registration order.
    #[cfg(feature = "web")]
    pub fn api_extensions(&self) -> impl Iterator<Item = &dyn ApiExtensionTypedDyn> {
        self.extensions.iter().filter_map(|e| e.as_api())
    }

    /// Returns every stored extension, in registration order.
    pub fn all_extensions(&self) -> impl Iterator<Item = &dyn AnyExtension> {
        self.extensions.iter().map(AsRef::as_ref)
    }

    /// Base paths recorded for API extensions, in registration order.
    #[must_use]
    pub fn api_paths(&self) -> &[String] {
        &self.api_paths
    }

    /// Number of stored extensions.
    #[must_use]
    pub fn len(&self) -> usize {
        self.extensions.len()
    }

    /// Returns `true` if no extensions are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.extensions.is_empty()
    }
}

impl std::fmt::Debug for TypedExtensionRegistry {
    /// Prints the extension count, the registered ids (in `HashMap` iteration order) and the
    /// recorded API paths; the extension objects themselves are omitted, shown as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.extensions.len();
        let ids: Vec<&String> = self.by_id.keys().collect();

        let mut out = f.debug_struct("TypedExtensionRegistry");
        out.field("count", &count);
        out.field("ids", &ids);
        out.field("api_paths", &self.api_paths);
        out.finish_non_exhaustive()
    }
}