Skip to main content

aiway_plugin/
lib.rs

1//! # 插件
2//! 插件是网关实现功能扩展的核心组件。插件目前仅支持使用Rust开发,并导出为`.so`格式的动态库给网关使用。
3//!
4//! ## 插件分类
5//! 按照插件的执行范围,可以分为全局插件和路由插件。
6//!
7//! ### 全局插件
8//! 全局插件对整个网关的所有请求生效(不含控制台请求,因为控制台是独立的)。
9//!
10//! 执行阶段:
11//! - 请求阶段:在请求到达API处理端点前执行,可对请求改写、安全验证、限流、缓存等。
12//! - 响应阶段:在API处理完成,响应客户端前执行,可修改响应、记录日志等。
13//!
14//! ### 路由插件
15//! 对特定路由生效。
16//!
17//! 路由插件和全局插件实现方式相同,仅执行时机不同。
18//!
19//! 执行阶段:
20//! - 请求阶段:在全局插件执行后,到达API处理端点前执行。
21//! - 响应阶段:在API处理完成,全局插件执行前执行。
22//!
23//! 注意:全局插件的优先级高于路由插件。
24//!
25//! ### 错误处理
26//! 插件执行时可能发生错误,当某个插件返回`Err`时,插件执行流程会中断,整个请求将失败,网关将返回`502`错误码。
27//!
28//! ## 使用方式
29//! ```rust
30//! use aiway_plugin::protocol::gateway::HttpContext;
31//! use aiway_plugin::serde_json::Value;
32//! use aiway_plugin::{Plugin, PluginError, PluginInfo, Version, async_trait, export, plugin_version};
33//!
34//! // 示例插件
35//! pub struct DemoPlugin;
36//!
37//! impl DemoPlugin {
38//!     pub fn new() -> Self {
39//!         Self {}
40//!     }
41//! }
42//!
43//! #[async_trait]
44//! impl Plugin for DemoPlugin {
45//!     fn name(&self) -> &'static str {
46//!         "demo"
47//!     }
48//!
49//!     fn info(&self) -> PluginInfo {
50//!         PluginInfo {
51//!             version: plugin_version!(),
52//!             default_config: Default::default(),
53//!             description: "Demo Plugin".to_string(),
54//!         }
55//!     }
56//!
57//!     // 实现插件逻辑
58//!     async fn execute(&self, _context: &HttpContext, _config: &Value) -> Result<Value, PluginError> {
59//!         //println!("run demo plugin, context: {:?}", context);
60//!         //println!("config: {:?}", config);
61//!         Ok(Default::default())
62//!     }
63//! }
64//!
65//! // 导出插件
66//! export!(DemoPlugin);
67//! ```
68//!
69//! ## 插件仓库
70//! https://github.com/xgpxg/aiway-plugins
71//!
72
73mod macros;
74mod manager;
75mod network;
76
77use crate::network::NETWORK;
78#[cfg(feature = "model")]
79pub use aiway_model_protocol as model_protocol;
80pub use aiway_protocol as protocol;
81pub use async_trait::async_trait;
82use libloading::Symbol;
83pub use manager::PluginManager;
84use protocol::context::HttpContext;
85pub use semver::Version;
86use serde::{Deserialize, Serialize};
87pub use serde_json;
88use serde_json::Value;
89use std::env::temp_dir;
90use std::fs;
91use std::fs::File;
92use std::io::Write;
93use std::path::PathBuf;
94#[derive(Debug)]
95pub enum PluginError {
96    /// 执行插件业务逻辑时的错误
97    ExecuteError(String),
98    /// 插件不存在
99    NotFound(String),
100    /// 从磁盘或网络加载插件时错误
101    LoadError(String),
102}
103
104impl std::fmt::Display for PluginError {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        match self {
107            PluginError::ExecuteError(msg) => write!(f, "{}", msg),
108            PluginError::NotFound(msg) => write!(f, "{}", msg),
109            PluginError::LoadError(msg) => write!(f, "{}", msg),
110        }
111    }
112}
113
114/// 插件定义
115///
116/// - name
117///
118/// 插件的名称,原则上不要重复。在`PluginManager`中,如果重复了,后添加的将被覆盖。
119///
120/// - execute
121///
122/// `execute`接收HttpContext参数,该HttpContext是可变的(内部可变性),可在插件逻辑内部修改请求和响应。
123/// 注意:当多个插件修改HttpContext的同一个属性时,后执行的插件会覆盖前一个插件的修改。
124/// 插件实现方应该自行决定插件运行阶段(请求阶段或者响应阶段),从而获取或修改request或response的数据。
125///
126/// - 返回值
127/// 返回[serde_json:Value]
128///
129#[async_trait]
130pub trait Plugin: Send + Sync {
131    /// 插件名称
132    fn name(&self) -> &str;
133    /// 插件信息
134    fn info(&self) -> PluginInfo;
135    /// 执行插件
136    async fn execute(&self, context: &HttpContext, config: &Value) -> Result<Value, PluginError>;
137}
138
139/// 插件信息
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct PluginInfo {
142    /// 插件版本
143    pub version: Version,
144    /// 默认配置
145    pub default_config: Value,
146    /// 描述
147    pub description: String,
148}
149
150impl TryFrom<PathBuf> for Box<dyn Plugin> {
151    type Error = PluginError;
152
153    fn try_from(value: PathBuf) -> Result<Self, Self::Error> {
154        unsafe {
155            let lib = libloading::Library::new(&value)
156                .map_err(|e| PluginError::LoadError(e.to_string()))?;
157
158            let create_plugin: Symbol<unsafe extern "C" fn() -> *mut dyn Plugin> = lib
159                .get(b"create_plugin")
160                .map_err(|e| PluginError::LoadError(e.to_string()))?;
161
162            let plugin_ptr = create_plugin();
163
164            if plugin_ptr.is_null() {
165                return Err(PluginError::LoadError(
166                    "Failed to create plugin: ptr is null".to_string(),
167                ));
168            }
169
170            let plugin = Box::from_raw(plugin_ptr);
171
172            // 包装一层,保持对lib的引用
173            let wrapped_plugin = Box::new(LibraryPluginWrapper { plugin, _lib: lib });
174
175            Ok(wrapped_plugin)
176        }
177    }
178}
179
180struct LibraryPluginWrapper {
181    plugin: Box<dyn Plugin>,
182    _lib: libloading::Library,
183}
184
185#[async_trait]
186impl Plugin for LibraryPluginWrapper {
187    fn name(&self) -> &str {
188        self.plugin.name()
189    }
190
191    fn info(&self) -> PluginInfo {
192        self.plugin.info()
193    }
194
195    async fn execute(&self, context: &HttpContext, config: &Value) -> Result<Value, PluginError> {
196        self.plugin.execute(context, config).await
197    }
198}
199
200impl Drop for LibraryPluginWrapper {
201    fn drop(&mut self) {
202        unsafe {
203            let destructor: Symbol<unsafe extern "C" fn(*mut dyn Plugin)> = self
204                ._lib
205                .get(b"destroy_plugin")
206                .expect("Failed to get destructor function");
207
208            destructor(self.plugin.as_mut());
209        }
210    }
211}
212
213/// 从指定的URL加载插件
214pub struct NetworkPlugin(pub String);
215
216#[async_trait]
217pub trait AsyncTryInto<T>: Sized {
218    type Error;
219
220    async fn async_try_into(self) -> Result<T, Self::Error>;
221}
222
223#[async_trait]
224impl AsyncTryInto<Box<dyn Plugin>> for NetworkPlugin {
225    type Error = PluginError;
226
227    async fn async_try_into(self) -> Result<Box<dyn Plugin>, Self::Error> {
228        let response = NETWORK
229            .client
230            .get(&self.0)
231            .send()
232            .await
233            .map_err(|e| PluginError::LoadError(e.to_string()))?
234            .error_for_status()
235            .map_err(|e| PluginError::LoadError(e.to_string()))?;
236
237        let bytes = response
238            .bytes()
239            .await
240            .map_err(|e| PluginError::LoadError(e.to_string()))?;
241
242        let tpf = temp_dir().join(uuid::Uuid::new_v4().to_string());
243
244        let plugin = {
245            let tpf = tpf.clone();
246            let mut file = File::create(&tpf).map_err(|e| PluginError::LoadError(e.to_string()))?;
247
248            file.write_all(&bytes)
249                .map_err(|e| PluginError::LoadError(e.to_string()))?;
250
251            drop(file);
252
253            tpf.try_into()
254        };
255
256        fs::remove_file(tpf).map_err(|e| PluginError::LoadError(e.to_string()))?;
257
258        plugin
259    }
260}
261
262impl TryFrom<Vec<u8>> for Box<dyn Plugin> {
263    type Error = PluginError;
264
265    fn try_from(from: Vec<u8>) -> Result<Box<dyn Plugin>, Self::Error> {
266        let temp = temp_dir().join(format!("{}.so", uuid::Uuid::new_v4()));
267        fs::write(&temp, from).map_err(|e| PluginError::LoadError(e.to_string()))?;
268        temp.try_into()
269    }
270}
271
272#[cfg(test)]
273mod tests {
274    use super::*;
275    use crate::manager::PluginManager;
276    use std::io::Read;
277    #[tokio::test]
278    async fn test_network_plugin() {
279        let p = NetworkPlugin(
280            "http://192.168.1.242:10000/aiway/test/plugins/libdemo_plugin.so".to_string(),
281        );
282        let plugin: Box<dyn Plugin> = p.async_try_into().await.unwrap();
283        plugin
284            .execute(&HttpContext::default(), &Value::Null)
285            .await
286            .unwrap();
287    }
288    #[tokio::test]
289    async fn test_plugin_manager() {
290        let p = NetworkPlugin(
291            "http://192.168.1.242:10000/aiway/test/plugins/libdemo_plugin.so".to_string(),
292        );
293        let plugin: Box<dyn Plugin> = p.async_try_into().await.unwrap();
294        let mut manager = PluginManager::new();
295        manager.register(plugin);
296        manager
297            .run("demo", &HttpContext::default(), &Value::Null)
298            .await
299            .unwrap();
300    }
301
302    #[tokio::test]
303    async fn test_plugin_from_bytes() {
304        let file =
305            File::open("../../target/release/libaha_model_request_wrapper_plugin.so").unwrap();
306        // 获取file的bytes
307        let bytes = file.bytes().collect::<Result<Vec<_>, _>>().unwrap();
308        let plugin: Box<dyn Plugin> = bytes.try_into().unwrap();
309        println!("{:?}", plugin.info());
310    }
311}