Skip to main content

aiway_plugin/
lib.rs

1#[doc = include_str!("../README.md")]
2mod macros;
3mod network;
4
5use crate::network::NETWORK;
6// #[cfg(feature = "model")]
7// pub use aiway_model_protocol as model_protocol;
8pub use aiway_protocol as protocol;
9use aiway_protocol::context::http::{request, response};
10pub use async_trait::async_trait;
11pub use bytes::Bytes;
12use libloading::Symbol;
13pub use log;
14use protocol::context::HttpContext;
15pub use semver::Version;
16use serde::{Deserialize, Serialize};
17pub use serde_json;
18use serde_json::Value;
19use std::env::temp_dir;
20use std::fs;
21use std::fs::File;
22use std::io::Write;
23use std::path::PathBuf;
/// Errors produced while locating, loading, or running a plugin.
///
/// Every variant carries a human-readable message; that message is exactly
/// what the [`Display`](std::fmt::Display) implementation prints.
#[derive(Debug)]
pub enum PluginError {
    /// The plugin's business logic failed while executing.
    ExecuteError(String),
    /// The requested plugin does not exist.
    NotFound(String),
    /// The plugin could not be loaded from disk or from the network.
    LoadError(String),
}

impl std::fmt::Display for PluginError {
    /// Writes the message carried by the variant, with no prefix or suffix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PluginError::ExecuteError(msg)
            | PluginError::NotFound(msg)
            | PluginError::LoadError(msg) => f.write_str(msg),
        }
    }
}

// Implementing `Error` lets callers propagate `PluginError` with `?` into
// `Box<dyn std::error::Error>` and other generic error containers.
impl std::error::Error for PluginError {}
43
44#[async_trait]
45pub trait Plugin: Send + Sync {
46    /// 插件名称
47    fn name(&self) -> &str;
48    /// 插件信息
49    fn info(&self) -> PluginInfo;
50
51    /// 请求阶段,可改写头部
52    async fn on_request(
53        &self,
54        _config: &Value,
55        _head: &mut request::Parts,
56        _ctx: &mut HttpContext,
57    ) -> Result<(), PluginError> {
58        Ok(())
59    }
60
61    /// 请求体阶段,可改写请求体
62    async fn on_request_body(
63        &self,
64        _config: &Value,
65        _body: &mut Option<Bytes>,
66        _ctx: &mut HttpContext,
67    ) -> Result<(), PluginError> {
68        Ok(())
69    }
70
71    /// 响应阶段,可改写头部
72    async fn on_response(
73        &self,
74        _config: &Value,
75        _head: &mut response::Parts,
76        _ctx: &mut HttpContext,
77    ) -> Result<(), PluginError> {
78        Ok(())
79    }
80
81    /// 响应体阶段,可改写响应体
82    fn on_response_body(
83        &self,
84        _config: &Value,
85        _body: &mut Option<Bytes>,
86        _ctx: &mut HttpContext,
87    ) -> Result<(), PluginError> {
88        Ok(())
89    }
90
91    async fn on_logging(&self, _: &Value, _: &mut HttpContext) {}
92}
93
94/// 插件信息
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct PluginInfo {
97    /// 插件版本
98    pub version: Version,
99    /// 默认配置
100    pub default_config: Value,
101    /// 描述
102    pub description: String,
103}
104
105impl TryFrom<PathBuf> for Box<dyn Plugin> {
106    type Error = PluginError;
107
108    fn try_from(value: PathBuf) -> Result<Self, Self::Error> {
109        unsafe {
110            let lib = libloading::Library::new(&value)
111                .map_err(|e| PluginError::LoadError(e.to_string()))?;
112
113            let create_plugin: Symbol<unsafe extern "C" fn() -> *mut dyn Plugin> = lib
114                .get(b"create_plugin")
115                .map_err(|e| PluginError::LoadError(e.to_string()))?;
116
117            let plugin_ptr = create_plugin();
118
119            if plugin_ptr.is_null() {
120                return Err(PluginError::LoadError(
121                    "Failed to create plugin: ptr is null".to_string(),
122                ));
123            }
124
125            let plugin = Box::from_raw(plugin_ptr);
126
127            // 包装一层,保持对lib的引用
128            let wrapped_plugin = Box::new(LibraryPluginWrapper { plugin, _lib: lib });
129
130            Ok(wrapped_plugin)
131        }
132    }
133}
134
135struct LibraryPluginWrapper {
136    plugin: Box<dyn Plugin>,
137    _lib: libloading::Library,
138}
139
140#[async_trait]
141impl Plugin for LibraryPluginWrapper {
142    fn name(&self) -> &str {
143        self.plugin.name()
144    }
145
146    fn info(&self) -> PluginInfo {
147        self.plugin.info()
148    }
149
150    async fn on_request(
151        &self,
152        config: &Value,
153        head: &mut request::Parts,
154        ctx: &mut HttpContext,
155    ) -> Result<(), PluginError> {
156        self.plugin.on_request(config, head, ctx).await
157    }
158
159    async fn on_request_body(
160        &self,
161        config: &Value,
162        body: &mut Option<Bytes>,
163        ctx: &mut HttpContext,
164    ) -> Result<(), PluginError> {
165        self.plugin.on_request_body(config, body, ctx).await
166    }
167    async fn on_response(
168        &self,
169        config: &Value,
170        head: &mut response::Parts,
171        ctx: &mut HttpContext,
172    ) -> Result<(), PluginError> {
173        self.plugin.on_response(config, head, ctx).await
174    }
175    fn on_response_body(
176        &self,
177        config: &Value,
178        body: &mut Option<Bytes>,
179        ctx: &mut HttpContext,
180    ) -> Result<(), PluginError> {
181        self.plugin.on_response_body(config, body, ctx)
182    }
183
184    async fn on_logging(&self, _: &Value, _: &mut HttpContext) {}
185
186}
187
188impl Drop for LibraryPluginWrapper {
189    fn drop(&mut self) {
190        unsafe {
191            let destructor: Symbol<unsafe extern "C" fn(*mut dyn Plugin)> = self
192                ._lib
193                .get(b"destroy_plugin")
194                .expect("Failed to get destructor function");
195
196            destructor(self.plugin.as_mut());
197        }
198    }
199}
200
/// Loads a plugin from the given URL.
///
/// The wrapped `String` is the URL the plugin is downloaded from.
pub struct NetworkPlugin(pub String);
203
204#[async_trait]
205pub trait AsyncTryInto<T>: Sized {
206    type Error;
207
208    async fn async_try_into(self) -> Result<T, Self::Error>;
209}
210
211#[async_trait]
212impl AsyncTryInto<Box<dyn Plugin>> for NetworkPlugin {
213    type Error = PluginError;
214
215    async fn async_try_into(self) -> Result<Box<dyn Plugin>, Self::Error> {
216        let response = NETWORK
217            .client
218            .get(&self.0)
219            .send()
220            .await
221            .map_err(|e| PluginError::LoadError(e.to_string()))?
222            .error_for_status()
223            .map_err(|e| PluginError::LoadError(e.to_string()))?;
224
225        let bytes = response
226            .bytes()
227            .await
228            .map_err(|e| PluginError::LoadError(e.to_string()))?;
229
230        let tpf = temp_dir().join(uuid::Uuid::new_v4().to_string());
231
232        let plugin = {
233            let tpf = tpf.clone();
234            let mut file = File::create(&tpf).map_err(|e| PluginError::LoadError(e.to_string()))?;
235
236            file.write_all(&bytes)
237                .map_err(|e| PluginError::LoadError(e.to_string()))?;
238
239            drop(file);
240
241            tpf.try_into()
242        };
243
244        fs::remove_file(tpf).map_err(|e| PluginError::LoadError(e.to_string()))?;
245
246        plugin
247    }
248}
249
250impl TryFrom<Vec<u8>> for Box<dyn Plugin> {
251    type Error = PluginError;
252
253    fn try_from(from: Vec<u8>) -> Result<Box<dyn Plugin>, Self::Error> {
254        let temp = temp_dir().join(format!("{}.so", uuid::Uuid::new_v4()));
255        fs::write(&temp, from).map_err(|e| PluginError::LoadError(e.to_string()))?;
256        temp.try_into()
257    }
258}