use axum::http::HeaderValue;
use parking_lot::RwLock;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::RenderResponse;
/// Per-request Inertia state, shared behind interior mutability so it can be
/// updated through `&self`.
pub struct InertiaExtension {
    // --- JSON payloads ---
    /// Page props; `add_prop`/`extend_props` write them, `take_props` moves them out.
    props: RwLock<serde_json::Value>,
    /// Flash props, created lazily by `add_flash_prop` and drained by `take_flash_props`.
    flash_props: RwLock<Option<serde_json::Value>>,

    // --- prop-name lists ---
    /// Names registered through `defer_prop`.
    deferred_props: RwLock<Vec<&'static str>>,
    /// Names registered through `merge_prop`.
    merge_props: RwLock<Vec<&'static str>>,

    // --- page metadata ---
    /// Component name set through `set_component`.
    component: RwLock<Option<&'static str>>,
    /// Asset version set through `set_version`.
    version: RwLock<Option<&'static str>>,
    /// Redirect header value set by `set_redirect` and drained by `take_redirect`.
    redirect: RwLock<Option<HeaderValue>>,

    // --- history flags ---
    /// Flag toggled by `encrypt_history`, read by `get_encrypt_history`.
    encrypt_history: AtomicBool,
    /// Flag toggled by `clear_history`, read by `get_clear_history`.
    clear_history: AtomicBool,

    // --- HTML shell ---
    /// `(head, tail)` HTML placed before and after the escaped JSON in `render_html`.
    html_sandwich: RwLock<(String, String)>,
}
impl Default for InertiaExtension {
    /// Builds an empty extension.
    ///
    /// Props start as `{"errors": {}}` (presumably so an `errors` key is always
    /// present for the client — confirm). Every optional slot starts as `None`,
    /// every list starts empty, both history flags start `false`, and the HTML
    /// head/tail start as empty strings.
    fn default() -> Self {
        let initial_props = serde_json::json!({ "errors": {} });
        let empty_shell = (String::new(), String::new());

        Self {
            props: RwLock::new(initial_props),
            flash_props: RwLock::new(None),
            deferred_props: RwLock::new(Vec::new()),
            merge_props: RwLock::new(Vec::new()),
            component: RwLock::new(None),
            version: RwLock::new(None),
            redirect: RwLock::new(None),
            encrypt_history: AtomicBool::new(false),
            clear_history: AtomicBool::new(false),
            html_sandwich: RwLock::new(empty_shell),
        }
    }
}
impl InertiaExtension {
    /// Sets (or overwrites) the top-level prop `key`.
    ///
    /// If the props are currently `null` (for example after [`Self::take_props`]),
    /// serde_json's `IndexMut` first turns them into an empty object.
    ///
    /// # Panics
    /// Panics if the stored props are neither an object nor `null`. This is
    /// serde_json's `IndexMut` behaviour; the methods here never produce such a value.
    pub fn add_prop(&self, key: &'static str, value: serde_json::Value) {
        let mut props = self.props.write();
        props[key] = value;
    }

    /// Records the asset version returned by [`Self::version`].
    pub fn set_version(&self, v: &'static str) {
        *self.version.write() = Some(v);
    }

    /// Serializes `rr` to JSON, HTML-escapes it, and wraps it between the
    /// head and tail configured via [`Self::set_html_sandwich`].
    ///
    /// # Panics
    /// Panics if `rr` fails to serialize to JSON.
    pub async fn render_html(&self, rr: RenderResponse<'_>) -> String {
        let rr = serde_json::to_string(&rr).expect("props should be serializable");
        // Escaping lets the JSON be embedded safely inside HTML markup.
        let rr = askama_escape::escape(&rr, askama_escape::Html);
        let (head, tail) = &*self.html_sandwich.read();
        format!("{head}{rr}{tail}")
    }

    /// Sets (or overwrites) the flash prop `key`, creating the flash object on first use.
    pub fn add_flash_prop(&self, key: &'static str, value: serde_json::Value) {
        let mut guard = self.flash_props.write();
        // `get_or_insert_with` replaces the former `is_none` check plus
        // `unsafe { unwrap_unchecked }` with a safe equivalent.
        let flash = guard.get_or_insert_with(|| serde_json::json!({}));
        flash[key] = value;
    }

    /// Moves the flash props out, leaving `None` behind.
    pub fn take_flash_props(&self) -> Option<serde_json::Value> {
        self.flash_props.write().take()
    }

    /// Moves the props out, leaving `null` behind.
    ///
    /// The default `errors` entry is not restored. Later `add_prop` or
    /// `extend_props` calls start again from an empty object.
    pub fn take_props(&self) -> serde_json::Value {
        self.props.write().take()
    }

    /// Shallow-merges the keys of `props` into the stored props. Incoming keys
    /// overwrite existing ones.
    ///
    /// Non-object `props` are ignored. If the stored props are `null` (e.g.
    /// after [`Self::take_props`]), they are first replaced by an empty object.
    /// This matches how [`Self::add_prop`] treats `null`; previously the merge
    /// was silently dropped in that state.
    pub fn extend_props(&self, props: serde_json::Value) {
        if let serde_json::Value::Object(incoming) = props {
            let mut base_props = self.props.write();
            if base_props.is_null() {
                *base_props = serde_json::Value::Object(serde_json::Map::new());
            }
            if let Some(base) = base_props.as_object_mut() {
                base.extend(incoming);
            }
        }
    }

    /// Sets the HTML placed before (`start`) and after (`end`) the escaped
    /// JSON in [`Self::render_html`].
    pub fn set_html_sandwich(&self, start: String, end: String) {
        *self.html_sandwich.write() = (start, end);
    }

    /// Returns the component name, if one has been set.
    pub fn component(&self) -> Option<&'static str> {
        *self.component.read()
    }

    /// Sets the component name.
    pub fn set_component(&self, component: &'static str) {
        *self.component.write() = Some(component);
    }

    /// Moves the pending redirect header out, leaving `None` behind.
    pub fn take_redirect(&self) -> Option<HeaderValue> {
        self.redirect.write().take()
    }

    /// Stores `redirect` as a header value to be retrieved with [`Self::take_redirect`].
    ///
    /// # Panics
    /// Panics if `redirect` is not a valid HTTP header value (for example, if it
    /// contains control characters).
    pub fn set_redirect(&self, redirect: &str) {
        let header = HeaderValue::from_str(redirect).expect("redirect header should be valid");
        *self.redirect.write() = Some(header);
    }

    /// Sets the encrypt-history flag.
    pub fn encrypt_history(&self, encrypt: bool) {
        self.encrypt_history.store(encrypt, Ordering::SeqCst);
    }

    /// Reads the encrypt-history flag.
    pub fn get_encrypt_history(&self) -> bool {
        self.encrypt_history.load(Ordering::SeqCst)
    }

    /// Sets the clear-history flag.
    pub fn clear_history(&self, clear: bool) {
        self.clear_history.store(clear, Ordering::SeqCst);
    }

    /// Reads the clear-history flag.
    pub fn get_clear_history(&self) -> bool {
        self.clear_history.load(Ordering::SeqCst)
    }

    /// Appends `prop` to the deferred-prop name list.
    pub fn defer_prop(&self, prop: &'static str) {
        self.deferred_props.write().push(prop);
    }

    /// Appends `prop` to the merge-prop name list.
    pub fn merge_prop(&self, prop: &'static str) {
        self.merge_props.write().push(prop);
    }

    /// Returns the asset version, if one has been set.
    pub fn version(&self) -> Option<&'static str> {
        *self.version.read()
    }
}