nostr_relay/
app.rs

1use crate::{setting::SettingWrapper, Extension, Extensions, Result, Server, Setting};
2use actix::Addr;
3use actix_cors::Cors;
4use actix_web::{
5    body::MessageBody,
6    dev::{ServiceFactory, ServiceRequest},
7    web, App as WebApp, HttpServer,
8};
9use nostr_db::Db;
10use parking_lot::RwLock;
11use std::{path::Path, sync::Arc};
12use tracing::info;
13
14pub mod route {
15    use crate::{App, Session};
16    use actix_web::http::header::{ACCEPT, LOCATION, UPGRADE};
17    use actix_web::{web, Error, HttpRequest, HttpResponse};
18    use actix_web_actors::ws;
19
20    fn get_ip(req: &HttpRequest, header: Option<&String>) -> Option<String> {
21        if let Some(header) = header {
22            // find from header list
23            // header.iter().find_map(|s| {
24            //     let hdr = req.headers().get(s)?.to_str().ok()?;
25            //     let val = hdr.split(',').next()?.trim();
26            //     Some(val.to_string())
27            // })
28            Some(
29                req.headers()
30                    .get(header)?
31                    .to_str()
32                    .ok()?
33                    .split(',')
34                    .next()?
35                    .trim()
36                    .to_string(),
37            )
38        } else {
39            Some(req.peer_addr()?.ip().to_string())
40        }
41    }
42
43    pub async fn websocket(
44        req: HttpRequest,
45        stream: web::Payload,
46        data: web::Data<App>,
47    ) -> Result<HttpResponse, Error> {
48        let r = data.setting.read();
49        let ip = get_ip(&req, r.network.real_ip_header.as_ref());
50        let max_size = r.limitation.max_message_length;
51        drop(r);
52
53        let session = Session::new(ip.unwrap_or_default(), data);
54
55        // ws::start(session, &req, stream)
56        // The default max frame size is 60k, change from setting.
57        ws::WsResponseBuilder::new(session, &req, stream)
58            .frame_size(max_size)
59            .start()
60    }
61
62    pub async fn information(
63        _req: HttpRequest,
64        _stream: web::Payload,
65        data: web::Data<App>,
66    ) -> Result<HttpResponse, Error> {
67        let r = data.setting.read();
68        Ok(HttpResponse::Ok()
69            .insert_header(("Content-Type", "application/nostr+json"))
70            .body(r.render_information()?))
71    }
72
73    pub async fn index(
74        req: HttpRequest,
75        stream: web::Payload,
76        data: web::Data<App>,
77    ) -> Result<HttpResponse, Error> {
78        let headers = req.headers();
79        if headers.contains_key(UPGRADE) {
80            return websocket(req, stream, data).await;
81        } else if let Some(accept) = headers.get(ACCEPT) {
82            if let Ok(accept) = accept.to_str() {
83                if accept.contains("application/nostr+json") {
84                    return information(req, stream, data).await;
85                }
86            }
87        }
88
89        let s = data.setting.read();
90        if let Some(site) = &s.network.index_redirect_to {
91            Ok(HttpResponse::Found()
92                .append_header((LOCATION, site.as_str()))
93                .finish())
94        } else {
95            Ok(HttpResponse::Ok().body(s.information.description.clone()))
96        }
97    }
98}
99
/// Shared application state handed to every HTTP worker as `web::Data<App>`.
///
/// Built by [`App::create`]. Extensions are attached with [`App::add_extension`].
/// NOTE: fields drop in declaration order.
pub struct App {
    /// Address of the relay [`Server`] actor started by `Server::create_with` in `App::create`.
    pub server: Addr<Server>,
    /// Event database opened under `<data path>/events`, shared with the server actor.
    pub db: Arc<Db>,
    /// Relay settings. They may be hot-reloaded when the app was created with `watch`.
    pub setting: SettingWrapper,
    /// Registered extensions, shared with the setting-watch callback and the web app configuration.
    pub extensions: Arc<RwLock<Extensions>>,
}
107
108impl App {
109    /// data_path: overwrite setting data path
110    pub fn create<P: AsRef<Path>>(
111        setting_path: Option<P>,
112        watch: bool,
113        setting_env_prefix: Option<String>,
114        data_path: Option<P>,
115    ) -> Result<Self> {
116        let extensions = Arc::new(RwLock::new(Extensions::default()));
117        let c_extensions = Arc::clone(&extensions);
118        let env_notice = setting_env_prefix
119            .as_ref()
120            .map(|s| {
121                format!(
122                    ", config will be overrided by ENV seting with prefix `{}_`",
123                    s
124                )
125            })
126            .unwrap_or_default();
127
128        let setting = if watch && setting_path.is_some() {
129            let path = setting_path.as_ref().unwrap().as_ref();
130            info!("Watch config file {:?}{}", path, env_notice);
131            SettingWrapper::watch(path, setting_env_prefix, move |s| {
132                let mut w = c_extensions.write();
133                w.call_setting(s);
134            })?
135        } else if let Some(path) = setting_path {
136            info!("Load config {:?}{}", path.as_ref(), env_notice);
137            Setting::read(path.as_ref(), setting_env_prefix)?.into()
138        } else if let Some(prefix) = setting_env_prefix {
139            info!("Load default config{}", env_notice);
140            Setting::from_env(prefix)?.into()
141        } else {
142            info!("Load default config");
143            Setting::default().into()
144        };
145
146        {
147            info!("{:?}", setting.read());
148        }
149
150        let r = setting.read();
151        let path = data_path
152            .map(|p| p.as_ref().to_path_buf())
153            .unwrap_or_else(|| r.data.path.clone())
154            .join("events");
155        drop(r);
156        let db = Arc::new(Db::open(path)?);
157        db.check_schema()?;
158
159        let server = Server::create_with(db.clone(), setting.clone());
160
161        Ok(Self {
162            server,
163            setting,
164            db,
165            extensions,
166        })
167    }
168
169    pub fn add_extension<E: Extension + 'static>(self, mut ext: E) -> Self {
170        info!("Add extension {}", ext.name());
171        ext.setting(&self.setting);
172        {
173            let mut w = self.extensions.write();
174            w.add(ext);
175        }
176        self
177    }
178
179    pub fn web_app(
180        self,
181    ) -> WebApp<
182        impl ServiceFactory<
183            ServiceRequest,
184            Config = (),
185            Response = actix_web::dev::ServiceResponse<impl MessageBody>,
186            Error = actix_web::Error,
187            InitError = (),
188        >,
189    > {
190        create_web_app(web::Data::new(self))
191    }
192
193    pub fn web_server(self) -> Result<actix_web::dev::Server, std::io::Error> {
194        let r = self.setting.read();
195        let num = if r.thread.http == 0 {
196            num_cpus::get()
197        } else {
198            r.thread.http
199        };
200        let host = r.network.host.clone();
201        let port = r.network.port;
202        drop(r);
203        info!("Start http server {}:{}", host, port);
204        let data = web::Data::new(self);
205        Ok(HttpServer::new(move || create_web_app(data.clone()))
206            .workers(num)
207            .bind((host, port))?
208            .run())
209    }
210}
211
212pub fn create_web_app(
213    data: web::Data<App>,
214) -> WebApp<
215    impl ServiceFactory<
216        ServiceRequest,
217        Config = (),
218        Response = actix_web::dev::ServiceResponse<impl MessageBody>,
219        Error = actix_web::Error,
220        InitError = (),
221    >,
222> {
223    let app = WebApp::new();
224    let extensions = data.extensions.clone();
225    app.app_data(data)
226        .configure(|cfg| {
227            extensions.write().call_config_web(cfg);
228        })
229        .service(web::resource("/").route(web::get().to(route::index)))
230        .wrap(
231            Cors::default()
232                .send_wildcard()
233                .allow_any_header()
234                .allow_any_origin()
235                .allow_any_method()
236                .max_age(86_400), // 24h
237        )
238}
239
#[cfg(test)]
pub mod tests {
    use crate::create_test_app;
    use actix_rt::time::sleep;
    use actix_test::read_body;
    use actix_web::{
        dev::Service,
        http::header::CONTENT_TYPE,
        test::{init_service, TestRequest},
    };
    use actix_web_actors::ws;
    use anyhow::Result;
    use bytes::Bytes;
    use futures_util::{SinkExt as _, StreamExt as _};
    use std::time::Duration;

    /// `GET /` with `Accept: application/nostr+json` returns the relay information
    /// document with the matching content type.
    #[actix_rt::test]
    async fn relay_info() -> Result<()> {
        let service = init_service(create_test_app("")?.web_app()).await;
        // Short settle delay before the first request, as in the original test.
        sleep(Duration::from_millis(50)).await;

        let request = TestRequest::with_uri("/")
            .insert_header(("Accept", "application/nostr+json"))
            .to_request();
        let response = service.call(request).await.unwrap();

        assert_eq!(response.status(), 200);
        let content_type = response.headers().get(CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "application/nostr+json");

        let body = String::from_utf8(read_body(response).await.to_vec())?;
        for expected in ["supported_nips", "limitation"] {
            assert!(body.contains(expected));
        }
        Ok(())
    }

    /// A websocket client gets a pong for its ping and an echoed normal close.
    #[actix_rt::test]
    async fn connect_ws() -> Result<()> {
        let mut server = actix_test::start(|| create_test_app("").unwrap().web_app());
        let mut client = server.ws_at("/").await.unwrap();

        client.send(ws::Message::Ping("text".into())).await?;
        let reply = client.next().await.unwrap()?;
        assert_eq!(reply, ws::Frame::Pong(Bytes::copy_from_slice(b"text")));

        client
            .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
            .await?;
        let reply = client.next().await.unwrap()?;
        assert_eq!(reply, ws::Frame::Close(Some(ws::CloseCode::Normal.into())));
        Ok(())
    }
}