iridis_file_ext/
file_ext.rs1use std::{collections::HashMap, mem::ManuallyDrop, path::PathBuf, sync::Arc};
5
6use crate::prelude::{
7    thirdparty::{libloading, tokio::task::JoinSet},
8    *,
9};
10
11pub struct FileExtManager {
13    pub plugins: HashMap<String, Arc<RuntimeFileExt>>,
14}
15
16pub struct FileExtLoader {
18    pub plugins: JoinSet<Result<(Vec<String>, RuntimeFileExt)>>,
19}
20
21impl FileExtManager {
22    pub fn new(plugins: HashMap<String, Arc<RuntimeFileExt>>) -> Self {
24        Self { plugins }
25    }
26
27    pub async fn load(
30        &self,
31        path: PathBuf,
32        inputs: Inputs,
33        outputs: Outputs,
34        queries: Queries,
35        queryables: Queryables,
36        configuration: serde_yml::Value,
37    ) -> Result<RuntimeNode> {
38        let ext = path
39            .extension()
40            .ok_or_eyre(format!("No extension found for path '{:?}'", path))?
41            .to_str()
42            .ok_or_eyre("Invalid extension")?;
43
44        let plugin = self
45            .plugins
46            .get(ext)
47            .ok_or_eyre(format!("Plugin not found for extension '{}'", ext))?;
48
49        plugin
50            .load(path, inputs, outputs, queries, queryables, configuration)
51            .await
52    }
53}
54
55impl FileExtLoader {
56    pub async fn new() -> Result<Self> {
58        Ok(FileExtLoader {
59            plugins: JoinSet::new(),
60        })
61    }
62
63    pub fn load_statically_linked_plugin<T: FileExtPlugin + 'static>(&mut self) {
70        self.plugins.spawn(async move {
71            let plugin = T::new().await?.wrap_err(format!(
72                "Failed to load static plugin '{}'",
73                std::any::type_name::<T>(),
74            ))?;
75
76            let plugin = RuntimeFileExt::StaticallyLinked(plugin);
77
78            tracing::debug!(
79                "Loaded statically linked plugin: {}",
80                std::any::type_name::<T>()
81            );
82
83            Ok((plugin.target(), plugin))
84        });
85    }
86
87    pub fn load_dynamically_linked_plugin(&mut self, path: PathBuf) {
94        self.plugins.spawn(async move {
95            match path.extension() {
96                Some(ext) => {
97                    if ext == std::env::consts::DLL_EXTENSION {
98                        let path_buf = path.clone();
99                        let (library, constructor) = tokio::task::spawn_blocking(move || {
100                            let library = unsafe {
101                                #[cfg(target_family = "unix")]
102                                let library = libloading::os::unix::Library::open(
103                                    Some(path_buf.clone()),
104                                    libloading::os::unix::RTLD_NOW | libloading::os::unix::RTLD_GLOBAL,
105                                )
106                                .wrap_err(format!("Failed to load path {:?}", path_buf))?;
107
108
109                                #[cfg(not(target_family = "unix"))]
110                                let library = Library::new(path_buf.clone())
111                                    .wrap_err(format!("Failed to load path {:?}", path_buf))?;
112
113                                library
114                            };
115
116                            let constructor = unsafe {
117                                library
118                                    .get::<*mut DynamicallyLinkedFileExtPluginInstance>(
119                                        b"IRIDIS_FILE_EXT_PLUGIN",
120                                    )
121                                    .wrap_err(format!(
122                                        "Failed to load symbol 'IRIDIS_FILE_EXT_PLUGIN' from cdylib {:?}",
123                                        path_buf
124                                    ))?
125                                    .read()
126                            };
127
128                            Ok::<_, eyre::Report>((library, constructor))
129                        })
130                        .await??;
131
132                        let plugin = RuntimeFileExt::DynamicallyLinked(
133                            DynamicallyLinkedFileExtPlugin::new(
134                                (constructor)().await?.wrap_err(format!(
135                                    "Failed to load dynamically linked plugin '{:?}'",
136                                    path,
137                                ))?,
138                                library,
139                            ),
140                        );
141
142                        tracing::debug!(
143                            "Loaded dynamically linked plugin from path: {}",
144                            path.display()
145                        );
146
147                        Ok((plugin.target(), plugin))
148                    } else {
149                        Err(eyre::eyre!("Extension '{:?}' is not supported", ext))
150                    }
151                }
152                _ => Err(eyre::eyre!("Unsupported path '{:?}'", path)),
153            }
154        });
155    }
156
157    pub async fn finish(mut self) -> Result<HashMap<String, Arc<RuntimeFileExt>>> {
160        let mut plugins = HashMap::new();
161
162        while let Some(result) = self.plugins.join_next().await {
163            let (targets, plugin) = result??;
164
165            let plugin = Arc::new(plugin);
166
167            for target in targets {
168                plugins.insert(target, plugin.clone());
169            }
170        }
171
172        Ok(plugins)
173    }
174}
175
176pub struct DynamicallyLinkedFileExtPlugin {
183    pub handle: ManuallyDrop<Box<dyn FileExtPlugin>>,
184
185    #[cfg(not(target_family = "unix"))]
186    pub library: ManuallyDrop<libloading::Library>,
187    #[cfg(target_family = "unix")]
188    pub library: ManuallyDrop<libloading::os::unix::Library>,
189}
190
191impl DynamicallyLinkedFileExtPlugin {
192    pub fn new(
195        handle: Box<dyn FileExtPlugin>,
196        #[cfg(not(target_family = "unix"))] library: libloading::Library,
197        #[cfg(target_family = "unix")] library: libloading::os::unix::Library,
198    ) -> Self {
199        Self {
200            handle: ManuallyDrop::new(handle),
201            library: ManuallyDrop::new(library),
202        }
203    }
204}
205
206impl Drop for DynamicallyLinkedFileExtPlugin {
207    fn drop(&mut self) {
208        unsafe {
209            ManuallyDrop::drop(&mut self.handle);
210            ManuallyDrop::drop(&mut self.library);
211        }
212    }
213}
214
215pub enum RuntimeFileExt {
218    StaticallyLinked(Box<dyn FileExtPlugin>),
219    DynamicallyLinked(DynamicallyLinkedFileExtPlugin),
220}
221
222impl RuntimeFileExt {
223    pub fn target(&self) -> Vec<String> {
225        match self {
226            RuntimeFileExt::StaticallyLinked(plugin) => plugin.target(),
227            RuntimeFileExt::DynamicallyLinked(plugin) => plugin.handle.target(),
228        }
229    }
230
231    pub async fn load(
233        &self,
234        path: PathBuf,
235        inputs: Inputs,
236        outputs: Outputs,
237        queries: Queries,
238        queryables: Queryables,
239        configuration: serde_yml::Value,
240    ) -> Result<RuntimeNode> {
241        match self {
242            RuntimeFileExt::StaticallyLinked(plugin) => {
243                plugin
244                    .load(path, inputs, outputs, queries, queryables, configuration)
245                    .await?
246            }
247            RuntimeFileExt::DynamicallyLinked(plugin) => {
248                plugin
249                    .handle
250                    .load(path, inputs, outputs, queries, queryables, configuration)
251                    .await?
252            }
253        }
254    }
255}