1pub use anyhow;
2pub use salvo;
3pub use salvo::catcher::Catcher;
4#[cfg(feature = "http3")]
5pub use salvo::conn::rustls::{Keycert, RustlsConfig};
6#[cfg(feature = "http3")]
7use salvo::conn::tcp::TcpAcceptor;
8
9use salvo::prelude::*;
10use salvo::serve_static::StaticDir;
11pub use serde_json::{self, Value};
12use std::{collections::HashMap, marker::PhantomData, sync::Arc};
13pub use tera::{self, Context, Filter, Function, Tera};
14pub use tokio::{self};
15
/// User-registered Tera functions, keyed by the name templates call them by.
type TeraFunctionMap = HashMap<String, Arc<dyn Function + 'static>>;
/// User-registered Tera filters, keyed by the name templates apply them by.
type TeraFilterMap = HashMap<String, Arc<dyn Filter + 'static>>;
/// Optional callback that derives template variables from the incoming request;
/// every returned key/value pair is inserted into the render context.
type MetaInfoCollector =
    Option<Arc<dyn Fn(&Request) -> HashMap<String, Value> + 'static + Send + Sync>>;

/// Optional hook that receives the request and the resolved template path and
/// returns the template path that is actually rendered.
type HookViewPathHandlerType = Option<Arc<dyn Fn(&mut Request, String) -> String + Send + Sync>>;
22struct CallableObjectForTera<F: ?Sized>(Arc<F>);
23
24impl<F: Function + ?Sized> Function for CallableObjectForTera<F> {
25 fn call(&self, args: &HashMap<String, Value>) -> tera::Result<Value> {
26 self.0.call(args)
27 }
28}
29
30impl<F: Filter + ?Sized> Filter for CallableObjectForTera<F> {
31 fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
32 self.0.filter(value, args)
33 }
34}
35
/// Certificate / private-key file pair used when serving over HTTP/3 (QUIC)
/// with rustls.
///
/// Both files are read from disk when the server starts.
#[derive(Debug, Clone)]
pub struct Http3Certification {
    /// Path to the certificate file.
    pub cert: std::path::PathBuf,
    /// Path to the private key file.
    pub key: std::path::PathBuf,
}
40
/// Server-side rendering application.
///
/// Serves static assets from a public directory and renders Tera templates
/// for every other GET path. `ErrorWriter` is the error type the view handler
/// returns; it must be constructible from both `anyhow::Error` and
/// `tera::Error`.
pub struct SSRender<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error> = anyhow::Error> {
    // Directory served as static assets, also used as the URL prefix ("public" by default).
    pub_assets_dir_name: String,
    // Directory holding the Tera templates ("templates" by default).
    tmpl_dir_name: String,
    // Address passed to the TCP/QUIC listeners.
    host: String,
    // User functions registered on every Tera instance that gets built.
    tmpl_func_map: TeraFunctionMap,
    // User filters registered on every Tera instance that gets built.
    tmpl_filter_map: TeraFilterMap,
    // Optional per-request context collector.
    ctx_generator: MetaInfoCollector,
    // Only records the `ErrorWriter` type parameter; no value is stored.
    phantom_data_: PhantomData<ErrorWriter>,
    // Postfix appended to view paths without an extension ("html" by default).
    default_view_file_postfix: String,
    // Template rendered for the empty path ("index.html" by default).
    default_view_file_name: String,
    // Whether the static directory auto-lists its contents (true by default).
    listing_assets: bool,
    // Optional default file name served for static-directory requests.
    default_asset_filename: Option<String>,
    #[cfg(feature = "http3")]
    // When set, `serve` listens with HTTP/3 (QUIC) joined to a rustls TCP listener.
    use_http3: Option<Http3Certification>,
    // Optional hook that can rewrite the resolved template path per request.
    hook_view_path: HookViewPathHandlerType,
}
57impl<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error> + Send + Sync + 'static>
58 SSRender<ErrorWriter>
59{
60 pub fn new(host: &str) -> Self {
61 Self {
62 pub_assets_dir_name: "public".to_owned(),
63 tmpl_dir_name: "templates".to_owned(),
64 host: host.to_owned(),
65 tmpl_func_map: HashMap::new(),
66 tmpl_filter_map: HashMap::new(),
67 ctx_generator: None,
68 phantom_data_: PhantomData,
69 default_view_file_postfix: "html".to_owned(),
70 default_view_file_name: "index.html".to_owned(),
71 listing_assets: true,
72 default_asset_filename: None,
73 #[cfg(feature = "http3")]
74 use_http3: None,
75 hook_view_path: None,
76 }
77 }
78
79 pub fn host(&self) -> &str {
80 &self.host
81 }
82
83 pub fn set_pub_dir_name(&mut self, path: &str) {
84 self.pub_assets_dir_name = path.to_owned();
85 }
86
87 pub fn set_tmpl_dir_name(&mut self, path: &str) {
88 self.tmpl_dir_name = path.to_owned();
89 }
90
91 pub fn register_function<F: Function + 'static>(&mut self, k: String, f: F) {
92 self.tmpl_func_map.insert(k, Arc::new(f));
93 }
94
95 pub fn rm_registed_function(&mut self, k: String) {
96 self.tmpl_func_map.remove(&k);
97 }
98
99 pub fn registed_functions(&self) -> &TeraFunctionMap {
100 &self.tmpl_func_map
101 }
102
103 pub fn register_filter<F: Filter + 'static>(&mut self, k: String, f: F) {
104 self.tmpl_filter_map.insert(k, Arc::new(f));
105 }
106
107 pub fn rm_registed_filter(&mut self, k: String) {
108 self.tmpl_filter_map.remove(&k);
109 }
110
111 pub fn registed_filters(&self) -> &TeraFilterMap {
112 &self.tmpl_filter_map
113 }
114
115 pub fn pub_dir_name(&self) -> &str {
116 &self.pub_assets_dir_name
117 }
118
119 pub fn tmpl_dir_name(&self) -> &str {
120 &self.tmpl_dir_name
121 }
122
123 pub fn set_ctx_generator(
124 &mut self,
125 f: impl Fn(&Request) -> HashMap<String, Value> + 'static + Send + Sync,
126 ) {
127 self.ctx_generator = Some(Arc::new(f));
128 }
129
130 pub fn rm_ctx_generator(&mut self) {
131 self.ctx_generator = None;
132 }
133
134 pub fn gen_tera_builder(&self) -> TeraBuilder {
135 TeraBuilder::new(
136 format!("{}/**/*", self.tmpl_dir_name),
137 self.tmpl_func_map.clone(),
138 self.tmpl_filter_map.clone(),
139 self.ctx_generator.clone(),
140 )
141 }
142
143 pub fn set_default_file_postfix(&mut self, postfix: &str) {
144 self.default_view_file_postfix = postfix.to_owned();
145 }
146
147 pub fn default_file_postfix(&self) -> &str {
148 &self.default_view_file_postfix
149 }
150
151 pub fn set_listing_assets(&mut self, v: bool) {
152 self.listing_assets = v;
153 }
154
155 pub fn listing_assets(&self) -> bool {
156 self.listing_assets
157 }
158
159 pub fn set_default_assets_filename(&mut self, v: &str) {
160 self.default_asset_filename = Some(v.to_owned());
161 }
162
163 pub fn default_assets_filename(&self) -> &Option<String> {
164 &self.default_asset_filename
165 }
166 #[cfg(feature = "http3")]
167 pub fn set_use_http3(&mut self, cert: Http3Certification) {
168 self.use_http3 = Some(cert);
169 }
170 #[cfg(feature = "http3")]
171 pub fn use_http3(&self) -> Option<&Http3Certification> {
172 self.use_http3.as_ref()
173 }
174
175 pub fn set_hook_view_path<F: Fn(&mut Request, String) -> String + 'static + Send + Sync>(
176 &mut self,
177 hook: Option<F>,
178 ) {
179 match hook {
180 Some(f) => {
181 self.hook_view_path = Some(Arc::new(f));
182 }
183 None => {
184 self.hook_view_path = None;
185 }
186 }
187 }
188
189 pub fn hook_view_path(&self) -> &HookViewPathHandlerType {
190 &self.hook_view_path
191 }
192
193 pub async fn serve(&self, extend_router: Option<Router>, catcher: Option<Catcher>) {
194 let pub_assets_router = Router::with_path(format!("{}/<**>", self.pub_assets_dir_name))
195 .get(
196 StaticDir::new([&self.pub_assets_dir_name])
197 .defaults(match &self.default_asset_filename {
198 Some(v) => {
199 vec![v.to_owned()]
200 }
201 None => {
202 vec![]
203 }
204 })
205 .auto_list(self.listing_assets),
206 );
207 let view_router = Router::with_path("/<**rest_path>").get(ViewHandler::<ErrorWriter>::new(
208 self.gen_tera_builder(),
209 self.default_view_file_postfix.clone(),
210 self.default_view_file_name.clone(),
211 self.hook_view_path.clone(),
212 ));
213 let router = match extend_router {
216 Some(r) => r,
217 None => Router::new(),
218 };
219 let router = router.push(pub_assets_router);
220 let router = router.push(view_router);
221 #[cfg(feature = "http3")]
222 enum VariantAcceptor<U> {
223 NonHttp3(TcpAcceptor),
224 Http3(U),
225 }
226
227 #[cfg(feature = "http3")]
228 let var_acceptor = match self.use_http3.as_ref() {
229 Some(cert) => {
230 let cert_bytes = tokio::fs::read(&cert.cert).await.unwrap();
231 let key_bytes = tokio::fs::read(&cert.key).await.unwrap();
232 let config = RustlsConfig::new(
233 Keycert::new()
234 .cert(cert_bytes.as_slice())
235 .key(key_bytes.as_slice()),
236 );
237 let listener = TcpListener::new(self.host.clone()).rustls(config.clone());
238 let acceptor = QuinnListener::new(config, self.host.clone())
239 .join(listener)
240 .bind()
241 .await;
242 VariantAcceptor::Http3(acceptor)
243 }
244 None => {
245 let acceptor = TcpListener::new(&self.host).bind().await;
246 VariantAcceptor::NonHttp3(acceptor)
247 }
248 };
249
250 match catcher {
251 Some(catcher) => {
252 let service = Service::new(router).catcher(catcher);
253 #[cfg(feature = "http3")]
254 {
255 match var_acceptor {
256 VariantAcceptor::Http3(acceptor) => {
257 Server::new(acceptor).serve(service).await;
258 }
259 VariantAcceptor::NonHttp3(acceptor) => {
260 Server::new(acceptor).serve(service).await;
261 }
262 }
263 }
264 #[cfg(not(feature = "http3"))]
265 {
266 let acceptor = TcpListener::new(&self.host).bind().await;
267 Server::new(acceptor).serve(service).await;
268 }
269 }
270 None => {
271 #[cfg(feature = "http3")]
272 {
273 match var_acceptor {
274 VariantAcceptor::Http3(acceptor) => {
275 Server::new(acceptor).serve(router).await;
276 }
277 VariantAcceptor::NonHttp3(acceptor) => {
278 Server::new(acceptor).serve(router).await;
279 }
280 }
281 }
282 #[cfg(not(feature = "http3"))]
283 {
284 let acceptor = TcpListener::new(&self.host).bind().await;
285 Server::new(acceptor).serve(router).await;
286 }
287 }
288 };
289 }
290}
291
292pub struct TeraBuilder {
293 tpl_dir: String,
294 tpl_funcs: TeraFunctionMap,
295 tpl_filters: TeraFilterMap,
296 ctx_generator: MetaInfoCollector,
297}
298impl TeraBuilder {
299 pub fn new(
300 tpl_dir: String,
301 tpl_funcs: TeraFunctionMap,
302 tpl_filters: TeraFilterMap,
303 ctx_generator: MetaInfoCollector,
304 ) -> Self {
305 Self {
306 tpl_dir,
307 tpl_funcs,
308 tpl_filters,
309 ctx_generator,
310 }
311 }
312
313 fn register_utilities(&self, tera: &mut Tera) {
314 for (k, v) in &self.tpl_funcs {
315 tera.register_function(k, CallableObjectForTera(Arc::clone(v)));
316 }
317 for (k, v) in &self.tpl_filters {
318 tera.register_filter(k, CallableObjectForTera(Arc::clone(v)));
319 }
320 }
321
322 pub fn build(&self, ctx: Context) -> tera::Result<(Tera, Context)> {
323 let mut tera = Tera::new(&self.tpl_dir)?;
324 self.register_utilities(&mut tera);
325 tera.register_filter(
326 "json_decode",
327 |v: &Value, _args: &HashMap<String, Value>| -> tera::Result<Value> {
328 let v = v
329 .as_str()
330 .ok_or(tera::Error::msg("value must be a json object string"))?;
331 let v = serde_json::from_str::<Value>(v)?;
332 Ok(v)
333 },
334 );
335 tera.register_function("include_file", generate_include(tera.clone(), ctx.clone()));
336 Ok((tera, ctx))
337 }
338
339 pub fn gen_context(&self, req: &Request) -> Context {
340 match self.ctx_generator {
341 Some(ref collect) => {
342 let mut context = Context::new();
343 for (k, val) in collect(req) {
344 context.insert(k, &val);
345 }
346 context
347 }
348 None => Context::default(),
349 }
350 }
351}
352
353struct ViewHandler<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error> = anyhow::Error> {
354 tera_builder: TeraBuilder,
355 phantom_data_: PhantomData<ErrorWriter>,
356 default_postfix: String,
357 default_view_file_name: String,
358 hook_view_path: HookViewPathHandlerType,
359}
360impl<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error>> ViewHandler<ErrorWriter> {
361 fn new(
362 tera_builder: TeraBuilder,
363 default_postfix: String,
364 default_view_file_name: String,
365 hook_view_path: HookViewPathHandlerType,
366 ) -> Self {
367 Self {
368 tera_builder,
369 phantom_data_: PhantomData,
370 default_postfix,
371 default_view_file_name,
372 hook_view_path,
373 }
374 }
375}
376#[handler]
377impl<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error> + Send + Sync + 'static>
378 ViewHandler<ErrorWriter>
379{
380 async fn handle(
381 &self,
382 req: &mut Request,
383 _depot: &mut Depot,
384 res: &mut Response,
385 ) -> Result<(), ErrorWriter> {
386 let Some(path) = req.param::<String>("**rest_path") else {
387 res.status_code(StatusCode::BAD_REQUEST);
388 return Err(anyhow::format_err!("invalid request path").into());
389 };
390 let ctx = self.tera_builder.gen_context(req);
391 let path = if path.is_empty() {
392 self.default_view_file_name.to_string()
393 } else {
394 match path.rfind('.') {
395 Some(_) => path,
396 None => {
397 format!("{path}.{}", self.default_postfix)
398 }
399 }
400 };
401 let path = match &self.hook_view_path {
402 Some(f) => f(&mut *req, path),
403 None => path,
404 };
405 if !cfg!(debug_assertions) {
406 let (tera, ctx) = self.tera_builder.build(ctx.clone())?;
407 match tera.render(&path, &ctx) {
408 Ok(html) => {
409 res.render(Text::Html(html));
410 }
411 Err(e) => {
412 if let tera::ErrorKind::TemplateNotFound(_) = &e.kind {
413 res.status_code(StatusCode::NOT_FOUND);
414 } else {
415 res.status_code(StatusCode::BAD_REQUEST);
416 }
417 return Err(anyhow::format_err!("{}", e.to_string()).into());
418 }
419 };
420 } else {
421 match self.tera_builder.build(ctx.clone()) {
422 Ok((tera, ctx)) => match tera.render(&path, &ctx) {
423 Ok(s) => {
424 res.render(Text::Html(s));
425 }
426 Err(e) => {
427 if let tera::ErrorKind::TemplateNotFound(_) = &e.kind {
428 res.status_code(StatusCode::NOT_FOUND);
429 } else {
430 res.status_code(StatusCode::BAD_REQUEST);
431 }
432 return Err(anyhow::format_err!("{e:?}").into());
433 }
434 },
435 Err(e) => {
436 res.status_code(StatusCode::BAD_REQUEST);
437 return Err(anyhow::format_err!("{e:?}").into());
438 }
439 };
440 }
441 Ok(())
442 }
443}
444
445fn generate_include(tera: Tera, parent: Context) -> impl Function {
446 move |args: &HashMap<String, Value>| -> tera::Result<Value> {
447 let Some(file_path) = args.get("path") else {
448 return Err(tera::Error::msg("file does not exist in the template path"));
449 };
450 match args.get("context") {
451 Some(v) => {
452 let context_value = v
454 .as_str()
455 .ok_or(tera::Error::msg("context must be a json object string"))?;
456 let v = serde_json::from_str::<Value>(context_value)?;
457 let mut context = Context::from_value(serde_json::json!({ "context": v }))?;
458 let mut tera = tera.clone();
459 context.insert("__Parent", &parent.clone().into_json());
460 tera.register_function(
461 "include_file",
462 generate_include(tera.clone(), context.clone()),
463 );
464 let r = tera
465 .render(
466 file_path
467 .as_str()
468 .ok_or(tera::Error::msg("template render error"))?,
469 &context,
470 )?
471 .to_string();
472 Ok(Value::String(r))
473 }
474 None => {
475 let mut context =
476 Context::from_value(serde_json::json!({ "context": Value::Null }))?;
477 let mut tera = tera.clone();
478 context.insert("__Parent", &parent.clone().into_json());
479 tera.register_function(
480 "include_file",
481 generate_include(tera.clone(), context.clone()),
482 );
483 let r = tera
484 .render(
485 file_path
486 .as_str()
487 .ok_or(tera::Error::msg("template render error"))?,
488 &context,
489 )?
490 .to_string();
491 return Ok(Value::String(r));
492 }
493 }
494 }
495}
496
/// Runs an `SSRender` on a fresh multi-threaded Tokio runtime, blocking until
/// the server stops.
///
/// Forms:
/// - `ssr_work!(app)`: no extra routes, no catcher.
/// - `ssr_work!(app, router)`: extra routes.
/// - `ssr_work!(app, None)`: no extra routes (explicit form).
/// - `ssr_work!(app, router, catcher)`: extra routes and a catcher.
/// - `ssr_work!(app, None, catcher)`: catcher only.
///
/// # Panics
/// Panics if the Tokio runtime cannot be built.
#[macro_export]
macro_rules! ssr_work {
    // Internal rule: drive `$fut` to completion on a new multi-threaded
    // runtime. Kept first so the `@` token is rejected before any `$e:expr`
    // fragment parser is attempted.
    (@block_on $fut:expr) => {
        $crate::tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("ssr_work!: failed to build the Tokio runtime")
            .block_on($fut)
    };
    ($e:expr, None, $catcher:expr) => {
        $crate::ssr_work!(@block_on async {
            $e.serve(None, Some($catcher)).await;
        });
    };
    // Without this arm, `ssr_work!(app, None)` would expand to
    // `serve(Some(None), None)` and fail to type-check.
    ($e:expr, None) => {
        $crate::ssr_work!(@block_on async {
            $e.serve(None, None).await;
        });
    };
    ($e:expr, $router:expr, $catcher:expr) => {
        $crate::ssr_work!(@block_on async {
            $e.serve(Some($router), Some($catcher)).await;
        });
    };
    ($e:expr, $router:expr) => {
        $crate::ssr_work!(@block_on async {
            $e.serve(Some($router), None).await;
        });
    };
    ($e:expr) => {
        $crate::ssr_work!(@block_on async {
            $e.serve(None, None).await;
        });
    };
}