fusen-common 0.8.12

fusen-common
Documentation
use std::{fs, sync::Arc};
use tokio::sync::mpsc;
use tracing::error;

/// Sending half of a configuration-change notification channel.
///
/// Wraps the `mpsc::Sender` used to push [`ConfigResponse`] payloads; the
/// matching receiver is handed out by `HotConfigChangeListener::new`.
pub struct HotConfigChangeListener {
    /// Channel end on which new configuration payloads are sent.
    pub sender: mpsc::Sender<ConfigResponse>,
}

/// A raw configuration payload together with a tag naming its format.
pub struct ConfigResponse {
    /// Format tag describing how `content` is encoded (e.g. `"toml"`).
    pub content_type: String,
    /// The unparsed configuration text.
    pub content: String,
}

impl ConfigResponse {
    /// Borrows the format tag of this payload.
    pub fn content_type(&self) -> &String {
        let Self { content_type, .. } = self;
        content_type
    }

    /// Borrows the unparsed configuration text.
    pub fn content(&self) -> &String {
        let Self { content, .. } = self;
        content
    }
}

impl HotConfigChangeListener {
    /// Creates a listener together with the receiving end of its channel.
    ///
    /// The channel is a bounded `mpsc` channel with capacity 1. The listener
    /// keeps the sender; the receiver is returned to the caller.
    pub fn new() -> (Self, mpsc::Receiver<ConfigResponse>) {
        let (tx, rx) = mpsc::channel(1);
        let listener = Self { sender: tx };
        (listener, rx)
    }
}

/// Handle to a configuration value that can be replaced at runtime.
///
/// `build_hot_config` spawns a task that owns the current value; this handle
/// talks to that task through `hot_config_sender`.
pub struct ConfigManager<T> {
    /// The configuration that was passed to `build_hot_config`. Nothing visible
    /// here reassigns it afterwards; use `get_hot_config` to read the live value.
    pub config: Arc<T>,
    // Request channel to the background task: each message pairs a command with
    // the oneshot sender on which the task delivers its reply.
    hot_config_sender: tokio::sync::mpsc::UnboundedSender<(
        HotConfigSender<T>,
        tokio::sync::oneshot::Sender<HotConfigReceiver<T>>,
    )>,
}

/// Command sent from a `ConfigManager` handle to its background task.
pub enum HotConfigSender<T> {
    /// Ask for the currently stored configuration.
    GET,
    /// Replace the stored configuration with the carried value.
    CHANGE(T),
}

/// Reply delivered by the background task for a `HotConfigSender` command.
pub enum HotConfigReceiver<T> {
    /// Answer to `HotConfigSender::GET`: a shared pointer to the stored configuration.
    GET(std::sync::Arc<T>),
    /// Acknowledgement that a `HotConfigSender::CHANGE` has been applied.
    CHANGE,
}

impl<T> ConfigManager<T>
where
    T: serde::de::DeserializeOwned + Send + Sync + 'static,
{
    /// Builds a manager whose configuration can be swapped while the program runs.
    ///
    /// Two tasks are spawned:
    ///
    /// * a *state* task that owns the current configuration and answers
    ///   [`HotConfigSender::GET`] / [`HotConfigSender::CHANGE`] requests;
    /// * a *listener* task that parses every [`ConfigResponse`] arriving on
    ///   `listener` with `config_build` and forwards the result to the state
    ///   task. Payloads that fail to parse are logged and skipped.
    ///
    /// The state task ends once every request sender (this manager and the
    /// listener task) has been dropped.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    pub fn build_hot_config(
        config: T,
        mut listener: tokio::sync::mpsc::Receiver<crate::config::ConfigResponse>,
    ) -> Result<Self, crate::error::Error> {
        let (sender, mut receive) = tokio::sync::mpsc::unbounded_channel::<(
            HotConfigSender<T>,
            tokio::sync::oneshot::Sender<HotConfigReceiver<T>>,
        )>();
        let config = Arc::new(config);
        let mut cache = Arc::clone(&config);

        // State task. `recv()` yields `None` only after all senders are gone, and
        // it will keep yielding `None` from then on, so the task must stop here
        // instead of polling the closed channel in an endless loop.
        tokio::spawn(async move {
            while let Some((msg, reply)) = receive.recv().await {
                match msg {
                    HotConfigSender::GET => {
                        // The requester may have given up waiting; nothing to do then.
                        let _ = reply.send(HotConfigReceiver::GET(Arc::clone(&cache)));
                    }
                    HotConfigSender::CHANGE(ident) => {
                        cache = Arc::new(ident);
                        let _ = reply.send(HotConfigReceiver::CHANGE);
                    }
                }
            }
        });

        // Listener task: turns raw payloads into `T` and hands them to the state task.
        let update_sender = sender.clone();
        tokio::spawn(async move {
            while let Some(config_response) = listener.recv().await {
                let Ok(ident) = config_build(config_response) else {
                    error!("config_build error!");
                    continue;
                };
                let (reply, ack) = tokio::sync::oneshot::channel();
                // A failed send drops `reply`, which surfaces as an error on `ack` below.
                let _ = update_sender.send((HotConfigSender::CHANGE(ident), reply));
                if ack.await.is_err() {
                    error!("receive error!");
                }
            }
        });

        Ok(ConfigManager {
            config,
            hot_config_sender: sender,
        })
    }

    /// Returns the configuration currently held by the state task.
    ///
    /// # Panics
    ///
    /// Panics if the state task does not answer with a `GET` reply, which
    /// happens when that task is no longer running.
    pub async fn get_hot_config(&self) -> std::sync::Arc<T> {
        let (reply, answer) = tokio::sync::oneshot::channel();
        // `UnboundedSender::send` takes `&self`, so the sender need not be cloned.
        // A failed send drops `reply`, which makes `answer` resolve to an error.
        let _ = self.hot_config_sender.send((HotConfigSender::GET, reply));
        let Ok(HotConfigReceiver::GET(ident)) = answer.await else {
            panic!("impossibility !!!");
        };
        ident
    }
}

pub fn config_build<T: serde::de::DeserializeOwned>(
    config_response: ConfigResponse,
) -> Result<T, crate::error::Error> {
    match config_response.content_type().as_str() {
        "toml" => get_toml_by_context(config_response.content()),
        "yaml" => get_yaml_by_context(config_response.content()),
        r#type => Err(crate::error::Error::MessageError(r#type.to_string())),
    }
}

pub fn get_toml_by_context<T: serde::de::DeserializeOwned>(
    toml_context: &str,
) -> Result<T, crate::error::Error> {
    // 解析 TOML 文件内容
    let parsed_toml: toml::Value = toml_context
        .parse()
        .map_err(|error| crate::error::Error::ConfigError(Box::new(error)))?;
    let json = serde_json::json!(parsed_toml);
    T::deserialize(json).map_err(|error| crate::error::Error::ConfigError(Box::new(error)))
}

pub fn get_yaml_by_context<T: serde::de::DeserializeOwned>(
    yaml_context: &str,
) -> Result<T, crate::error::Error> {
    // 解析 yaml 文件内容
    let parsed_toml: serde_yaml::Value = serde_yaml::from_str(yaml_context)
        .map_err(|error| crate::error::Error::ConfigError(Box::new(error)))?;
    T::deserialize(parsed_toml).map_err(|error| crate::error::Error::ConfigError(Box::new(error)))
}

pub fn get_config_by_path<T: serde::de::DeserializeOwned>(
    path: &str,
) -> Result<T, crate::error::Error> {
    let contents = fs::read_to_string(path).unwrap_or_else(|_| panic!("read path erro : {path:?}"));
    let file_type: Vec<&str> = path.split('.').collect();
    match file_type[file_type.len() - 1].as_bytes() {
        b"toml" => get_toml_by_context(&contents),
        b"yaml" => get_yaml_by_context(&contents),
        file_type => Err(crate::error::Error::MessageError(format!(
            "not support {file_type:?}"
        ))),
    }
}