Skip to main content

aiway_plugin/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod macros;
4mod network;
5
6use crate::network::NETWORK;
7// #[cfg(feature = "model")]
8// pub use aiway_model_protocol as model_protocol;
9pub use aiway_protocol as protocol;
10use aiway_protocol::context::http::{request, response};
11pub use async_trait::async_trait;
12pub use bytes::Bytes;
13use libloading::Symbol;
14pub use log;
15use protocol::context::HttpContext;
16pub use semver::Version;
17use serde::{Deserialize, Serialize};
18pub use serde_json;
19use serde_json::Value;
20use std::env::temp_dir;
21use std::fs;
22use std::fs::File;
23use std::io::Write;
24use std::path::PathBuf;
/// Errors produced while loading or running a plugin.
///
/// Every variant carries a human-readable message; `Display` prints that message verbatim
/// (without a variant prefix).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PluginError {
    /// The plugin's business logic failed while executing.
    ExecuteError(String),
    /// The requested plugin does not exist.
    NotFound(String),
    /// The plugin could not be loaded from disk or from the network.
    LoadError(String),
}

impl std::fmt::Display for PluginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PluginError::ExecuteError(msg)
            | PluginError::NotFound(msg)
            | PluginError::LoadError(msg) => f.write_str(msg),
        }
    }
}

/// Makes `PluginError` usable as `Box<dyn std::error::Error>` and with error-handling crates.
/// There is no underlying cause to expose, so `source()` keeps its default (`None`).
impl std::error::Error for PluginError {}
44
45#[async_trait]
46pub trait Plugin: Send + Sync {
47    /// 插件名称
48    fn name(&self) -> &str;
49    /// 插件信息
50    fn info(&self) -> PluginInfo;
51
52    /// 请求阶段,可改写头部
53    async fn on_request(
54        &self,
55        _config: &Value,
56        _head: &mut request::Parts,
57        _ctx: &mut HttpContext,
58    ) -> Result<(), PluginError> {
59        Ok(())
60    }
61
62    /// 请求体阶段,可改写请求体
63    async fn on_request_body(
64        &self,
65        _config: &Value,
66        _body: &mut Option<Bytes>,
67        _ctx: &mut HttpContext,
68    ) -> Result<(), PluginError> {
69        Ok(())
70    }
71
72    /// 响应阶段,可改写头部
73    async fn on_response(
74        &self,
75        _config: &Value,
76        _head: &mut response::Parts,
77        _ctx: &mut HttpContext,
78    ) -> Result<(), PluginError> {
79        Ok(())
80    }
81
82    /// 响应体阶段,可改写响应体
83    fn on_response_body(
84        &self,
85        _config: &Value,
86        _body: &mut Option<Bytes>,
87        _ctx: &mut HttpContext,
88    ) -> Result<(), PluginError> {
89        Ok(())
90    }
91
92    async fn on_logging(&self, _: &Value, _: &mut HttpContext) {}
93}
94
95/// 插件信息
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct PluginInfo {
98    /// 插件版本
99    pub version: Version,
100    /// 默认配置
101    pub default_config: Value,
102    /// 描述
103    pub description: String,
104}
105
106impl TryFrom<PathBuf> for Box<dyn Plugin> {
107    type Error = PluginError;
108
109    fn try_from(value: PathBuf) -> Result<Self, Self::Error> {
110        unsafe {
111            let lib = libloading::Library::new(&value)
112                .map_err(|e| PluginError::LoadError(e.to_string()))?;
113
114            let create_plugin: Symbol<unsafe extern "C" fn() -> *mut dyn Plugin> = lib
115                .get(b"create_plugin")
116                .map_err(|e| PluginError::LoadError(e.to_string()))?;
117
118            let plugin_ptr = create_plugin();
119
120            if plugin_ptr.is_null() {
121                return Err(PluginError::LoadError(
122                    "Failed to create plugin: ptr is null".to_string(),
123                ));
124            }
125
126            let plugin = Box::from_raw(plugin_ptr);
127
128            // 包装一层,保持对lib的引用
129            let wrapped_plugin = Box::new(LibraryPluginWrapper { plugin, _lib: lib });
130
131            Ok(wrapped_plugin)
132        }
133    }
134}
135
136struct LibraryPluginWrapper {
137    plugin: Box<dyn Plugin>,
138    _lib: libloading::Library,
139}
140
141#[async_trait]
142impl Plugin for LibraryPluginWrapper {
143    fn name(&self) -> &str {
144        self.plugin.name()
145    }
146
147    fn info(&self) -> PluginInfo {
148        self.plugin.info()
149    }
150
151    async fn on_request(
152        &self,
153        config: &Value,
154        head: &mut request::Parts,
155        ctx: &mut HttpContext,
156    ) -> Result<(), PluginError> {
157        self.plugin.on_request(config, head, ctx).await
158    }
159
160    async fn on_request_body(
161        &self,
162        config: &Value,
163        body: &mut Option<Bytes>,
164        ctx: &mut HttpContext,
165    ) -> Result<(), PluginError> {
166        self.plugin.on_request_body(config, body, ctx).await
167    }
168    async fn on_response(
169        &self,
170        config: &Value,
171        head: &mut response::Parts,
172        ctx: &mut HttpContext,
173    ) -> Result<(), PluginError> {
174        self.plugin.on_response(config, head, ctx).await
175    }
176    fn on_response_body(
177        &self,
178        config: &Value,
179        body: &mut Option<Bytes>,
180        ctx: &mut HttpContext,
181    ) -> Result<(), PluginError> {
182        self.plugin.on_response_body(config, body, ctx)
183    }
184
185    async fn on_logging(&self, _: &Value, _: &mut HttpContext) {}
186
187}
188
189impl Drop for LibraryPluginWrapper {
190    fn drop(&mut self) {
191        unsafe {
192            let destructor: Symbol<unsafe extern "C" fn(*mut dyn Plugin)> = self
193                ._lib
194                .get(b"destroy_plugin")
195                .expect("Failed to get destructor function");
196
197            destructor(self.plugin.as_mut());
198        }
199    }
200}
201
/// A plugin that is downloaded from a URL and then loaded (see its `AsyncTryInto` impl).
///
/// The wrapped `String` is the URL to fetch the plugin library from.
pub struct NetworkPlugin(pub String);
204
205#[async_trait]
206pub trait AsyncTryInto<T>: Sized {
207    type Error;
208
209    async fn async_try_into(self) -> Result<T, Self::Error>;
210}
211
212#[async_trait]
213impl AsyncTryInto<Box<dyn Plugin>> for NetworkPlugin {
214    type Error = PluginError;
215
216    async fn async_try_into(self) -> Result<Box<dyn Plugin>, Self::Error> {
217        let response = NETWORK
218            .client
219            .get(&self.0)
220            .send()
221            .await
222            .map_err(|e| PluginError::LoadError(e.to_string()))?
223            .error_for_status()
224            .map_err(|e| PluginError::LoadError(e.to_string()))?;
225
226        let bytes = response
227            .bytes()
228            .await
229            .map_err(|e| PluginError::LoadError(e.to_string()))?;
230
231        let tpf = temp_dir().join(uuid::Uuid::new_v4().to_string());
232
233        let plugin = {
234            let tpf = tpf.clone();
235            let mut file = File::create(&tpf).map_err(|e| PluginError::LoadError(e.to_string()))?;
236
237            file.write_all(&bytes)
238                .map_err(|e| PluginError::LoadError(e.to_string()))?;
239
240            drop(file);
241
242            tpf.try_into()
243        };
244
245        fs::remove_file(tpf).map_err(|e| PluginError::LoadError(e.to_string()))?;
246
247        plugin
248    }
249}
250
251impl TryFrom<Vec<u8>> for Box<dyn Plugin> {
252    type Error = PluginError;
253
254    fn try_from(from: Vec<u8>) -> Result<Box<dyn Plugin>, Self::Error> {
255        let temp = temp_dir().join(format!("{}.so", uuid::Uuid::new_v4()));
256        fs::write(&temp, from).map_err(|e| PluginError::LoadError(e.to_string()))?;
257        temp.try_into()
258    }
259}