pyo3_matrix_synapse_module/
lib.rs

1// Copyright 2022 Quentin Gliech
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![forbid(unsafe_code)]
16#![deny(clippy::all)]
17#![warn(clippy::pedantic)]
18
19use bytes::{Buf, Bytes};
20use http::{Request, Response};
21use http_body::{Body, Full};
22use pyo3::{exceptions::PyValueError, types::PyType, FromPyObject, PyAny, PyErr, PyResult};
23use pyo3_twisted_web::Resource;
24use serde::Deserialize;
25use tower_service::Service;
26
27mod synapse {
28    pyo3::import_exception!(synapse.module_api.errors, ConfigError);
29}
30
31pub struct ModuleApi<'a> {
32    inner: &'a PyAny,
33}
34
35impl<'a> ModuleApi<'a> {
36    /// Register a [`Service`] to handle a path
37    ///
38    /// # Errors
39    ///
40    /// Returns an error if the call to `ModuleApi.register_web_resource` failed
41    pub fn register_web_service<S, B, E>(&self, path: &str, service: S) -> PyResult<()>
42    where
43        S: Service<Request<Full<Bytes>>, Response = Response<B>, Error = E>
44            + Clone
45            + Send
46            + 'static,
47        S::Future: Send,
48        B: Body + Send + 'static,
49        B::Data: Buf + 'static,
50        B::Error: Into<PyErr> + 'static,
51        E: Into<PyErr> + 'static,
52    {
53        self.inner.call_method1(
54            "register_web_resource",
55            (path, Resource::from_service(service)),
56        )?;
57        Ok(())
58    }
59}
60
61impl<'a> FromPyObject<'a> for ModuleApi<'a> {
62    fn extract(inner: &'a PyAny) -> PyResult<Self> {
63        let module_api_cls = inner
64            .py()
65            .import("synapse.module_api")?
66            .getattr("ModuleApi")?
67            .downcast::<PyType>()?;
68
69        if inner.is_instance(module_api_cls)? {
70            Ok(Self { inner })
71        } else {
72            Err(PyValueError::new_err(
73                "Object is not a synapse.module_api.ModuleApi",
74            ))
75        }
76    }
77}
78
79/// Convert a dict to `T` via `serde_json`, useful for implementing `parse_config`
80///
81/// # Errors
82///
83/// Returns an error if it failed to convert the dict
84pub fn parse_config<'a, T: Deserialize<'a>>(config: &'a PyAny) -> PyResult<T> {
85    let py = config.py();
86    let config: &str = py
87        .import("json")?
88        .call_method1("dumps", (config,))?
89        .extract()?;
90
91    let deserializer = &mut serde_json::Deserializer::from_str(config);
92    serde_path_to_error::deserialize(deserializer).map_err(|err| {
93        // Figure out the path where the error happened using `serde_path_to_error`
94        // XXX: This is probably good enough for now
95        let path: Vec<String> = err
96            .path()
97            .to_string()
98            .split('.')
99            .map(ToOwned::to_owned)
100            .collect();
101
102        // XXX: This is ugly, but it removes the " at line X column Y" from serde_json's errors
103        let mut message = err.into_inner().to_string();
104        if let Some(idx) = message.rfind(" at line ") {
105            message.truncate(idx);
106        }
107
108        synapse::ConfigError::new_err((message, path))
109    })
110}