aiway-plugin 0.1.2

The aiway plugin lib
Documentation
//! # 插件
//! 插件是网关实现功能扩展的核心组件。插件目前仅支持使用Rust开发,并导出为`.so`格式的动态库给网关使用。
//!
//! ## 插件分类
//! 按照插件的执行范围,可以分为全局插件和路由插件。
//!
//! ### 全局插件
//! 全局插件对整个网关的所有请求生效(不含控制台请求,因为控制台是独立的)。
//!
//! 执行阶段:
//! - 请求阶段:在请求到达API处理端点前执行,可对请求改写、安全验证、限流、缓存等。
//! - 响应阶段:在API处理完成,响应客户端前执行,可修改响应、记录日志等。
//!
//! ### 路由插件
//! 对特定路由生效。
//!
//! 路由插件和全局插件实现方式相同,仅执行时机不同。
//!
//! 执行阶段:
//! - 请求阶段:在全局插件执行后,到达API处理端点前执行。
//! - 响应阶段:在API处理完成,全局插件执行前执行。
//!
//! 注意:全局插件的优先级高于路由插件。
//!
//! ### 错误处理
//! 插件执行时可能发生错误,当某个插件返回`Err`时,插件执行流程会中断,整个请求将失败,网关将返回`502`错误码。
//!
//! ## 使用方式
//! ```rust
//! use aiway_plugin::protocol::gateway::HttpContext;
//! use aiway_plugin::serde_json::Value;
//! use aiway_plugin::{Plugin, PluginError, PluginInfo, Version, async_trait, export, plugin_version};
//!
//! // 示例插件
//! pub struct DemoPlugin;
//!
//! impl DemoPlugin {
//!     pub fn new() -> Self {
//!         Self {}
//!     }
//! }
//!
//! #[async_trait]
//! impl Plugin for DemoPlugin {
//!     fn name(&self) -> &'static str {
//!         "demo"
//!     }
//!
//!     fn info(&self) -> PluginInfo {
//!         PluginInfo {
//!             version: plugin_version!(),
//!             default_config: Default::default(),
//!             description: "Demo Plugin".to_string(),
//!         }
//!     }
//!
//!     // 实现插件逻辑
//!     async fn execute(&self, _context: &HttpContext, _config: &Value) -> Result<Value, PluginError> {
//!         //println!("run demo plugin, context: {:?}", context);
//!         //println!("config: {:?}", config);
//!         Ok(Default::default())
//!     }
//! }
//!
//! // 导出插件
//! export!(DemoPlugin);
//! ```
//!
//! ## 插件仓库
//! https://github.com/xgpxg/aiway-plugins
//!

mod macros;
mod manager;
mod network;

use crate::network::NETWORK;
#[cfg(feature = "model")]
pub use aiway_model_protocol as model_protocol;
pub use aiway_protocol as protocol;
pub use async_trait::async_trait;
use libloading::Symbol;
pub use manager::PluginManager;
use protocol::context::HttpContext;
pub use semver::Version;
use serde::{Deserialize, Serialize};
pub use serde_json;
use serde_json::Value;
use std::env::temp_dir;
use std::fs;
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
#[derive(Debug)]
pub enum PluginError {
    /// 执行插件业务逻辑时的错误
    ExecuteError(String),
    /// 插件不存在
    NotFound(String),
    /// 从磁盘或网络加载插件时错误
    LoadError(String),
}

impl std::fmt::Display for PluginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PluginError::ExecuteError(msg) => write!(f, "{}", msg),
            PluginError::NotFound(msg) => write!(f, "{}", msg),
            PluginError::LoadError(msg) => write!(f, "{}", msg),
        }
    }
}

/// 插件定义
///
/// - name
///
/// 插件的名称,原则上不要重复。在`PluginManager`中,如果重复了,后添加的将被覆盖。
///
/// - execute
///
/// `execute`接收HttpContext参数,该HttpContext是可变的(内部可变性),可在插件逻辑内部修改请求和响应。
/// 注意:当多个插件修改HttpContext的同一个属性时,后执行的插件会覆盖前一个插件的修改。
/// 插件实现方应该自行决定插件运行阶段(请求阶段或者响应阶段),从而获取或修改request或response的数据。
///
/// - 返回值
/// 返回[serde_json:Value]
///
#[async_trait]
pub trait Plugin: Send + Sync {
    /// 插件名称
    fn name(&self) -> &str;
    /// 插件信息
    fn info(&self) -> PluginInfo;
    /// 执行插件
    async fn execute(&self, context: &HttpContext, config: &Value) -> Result<Value, PluginError>;
}

/// 插件信息
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginInfo {
    /// 插件版本
    pub version: Version,
    /// 默认配置
    pub default_config: Value,
    /// 描述
    pub description: String,
}

impl TryFrom<PathBuf> for Box<dyn Plugin> {
    type Error = PluginError;

    fn try_from(value: PathBuf) -> Result<Self, Self::Error> {
        unsafe {
            let lib = libloading::Library::new(&value)
                .map_err(|e| PluginError::LoadError(e.to_string()))?;

            let create_plugin: Symbol<unsafe extern "C" fn() -> *mut dyn Plugin> = lib
                .get(b"create_plugin")
                .map_err(|e| PluginError::LoadError(e.to_string()))?;

            let plugin_ptr = create_plugin();

            if plugin_ptr.is_null() {
                return Err(PluginError::LoadError(
                    "Failed to create plugin: ptr is null".to_string(),
                ));
            }

            let plugin = Box::from_raw(plugin_ptr);

            // 包装一层,保持对lib的引用
            let wrapped_plugin = Box::new(LibraryPluginWrapper { plugin, _lib: lib });

            Ok(wrapped_plugin)
        }
    }
}

struct LibraryPluginWrapper {
    plugin: Box<dyn Plugin>,
    _lib: libloading::Library,
}

#[async_trait]
impl Plugin for LibraryPluginWrapper {
    fn name(&self) -> &str {
        self.plugin.name()
    }

    fn info(&self) -> PluginInfo {
        self.plugin.info()
    }

    async fn execute(&self, context: &HttpContext, config: &Value) -> Result<Value, PluginError> {
        self.plugin.execute(context, config).await
    }
}

impl Drop for LibraryPluginWrapper {
    fn drop(&mut self) {
        unsafe {
            let destructor: Symbol<unsafe extern "C" fn(*mut dyn Plugin)> = self
                ._lib
                .get(b"destroy_plugin")
                .expect("Failed to get destructor function");

            destructor(self.plugin.as_mut());
        }
    }
}

/// 从指定的URL加载插件
pub struct NetworkPlugin(pub String);

#[async_trait]
pub trait AsyncTryInto<T>: Sized {
    type Error;

    async fn async_try_into(self) -> Result<T, Self::Error>;
}

#[async_trait]
impl AsyncTryInto<Box<dyn Plugin>> for NetworkPlugin {
    type Error = PluginError;

    async fn async_try_into(self) -> Result<Box<dyn Plugin>, Self::Error> {
        let response = NETWORK
            .client
            .get(&self.0)
            .send()
            .await
            .map_err(|e| PluginError::LoadError(e.to_string()))?
            .error_for_status()
            .map_err(|e| PluginError::LoadError(e.to_string()))?;

        let bytes = response
            .bytes()
            .await
            .map_err(|e| PluginError::LoadError(e.to_string()))?;

        let tpf = temp_dir().join(uuid::Uuid::new_v4().to_string());

        let plugin = {
            let tpf = tpf.clone();
            let mut file = File::create(&tpf).map_err(|e| PluginError::LoadError(e.to_string()))?;

            file.write_all(&bytes)
                .map_err(|e| PluginError::LoadError(e.to_string()))?;

            drop(file);

            tpf.try_into()
        };

        fs::remove_file(tpf).map_err(|e| PluginError::LoadError(e.to_string()))?;

        plugin
    }
}

impl TryFrom<Vec<u8>> for Box<dyn Plugin> {
    type Error = PluginError;

    fn try_from(from: Vec<u8>) -> Result<Box<dyn Plugin>, Self::Error> {
        let temp = temp_dir().join(format!("{}.so", uuid::Uuid::new_v4()));
        fs::write(&temp, from).map_err(|e| PluginError::LoadError(e.to_string()))?;
        temp.try_into()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::manager::PluginManager;
    use std::io::Read;
    #[tokio::test]
    async fn test_network_plugin() {
        let p = NetworkPlugin(
            "http://192.168.1.242:10000/aiway/test/plugins/libdemo_plugin.so".to_string(),
        );
        let plugin: Box<dyn Plugin> = p.async_try_into().await.unwrap();
        plugin
            .execute(&HttpContext::default(), &Value::Null)
            .await
            .unwrap();
    }
    #[tokio::test]
    async fn test_plugin_manager() {
        let p = NetworkPlugin(
            "http://192.168.1.242:10000/aiway/test/plugins/libdemo_plugin.so".to_string(),
        );
        let plugin: Box<dyn Plugin> = p.async_try_into().await.unwrap();
        let mut manager = PluginManager::new();
        manager.register(plugin);
        manager
            .run("demo", &HttpContext::default(), &Value::Null)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_plugin_from_bytes() {
        let file =
            File::open("../../target/release/libaha_model_request_wrapper_plugin.so").unwrap();
        // 获取file的bytes
        let bytes = file.bytes().collect::<Result<Vec<_>, _>>().unwrap();
        let plugin: Box<dyn Plugin> = bytes.try_into().unwrap();
        println!("{:?}", plugin.info());
    }
}