fusen_common/
config.rs

1use std::{fs, sync::Arc};
2use tokio::sync::mpsc;
3use tracing::error;
4
5pub struct HotConfigChangeListener {
6    pub sender: mpsc::Sender<ConfigResponse>,
7}
8
/// A raw configuration payload together with the format it is encoded in.
///
/// `content_type` selects the parser used by `config_build`, which recognizes
/// `"toml"` and `"yaml"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigResponse {
    /// Format tag of `content`, e.g. `"toml"` or `"yaml"`.
    pub content_type: String,
    /// The unparsed configuration text.
    pub content: String,
}

impl ConfigResponse {
    /// Returns the format tag of the payload.
    pub fn content_type(&self) -> &String {
        &self.content_type
    }

    /// Returns the raw, unparsed configuration text.
    pub fn content(&self) -> &String {
        &self.content
    }
}
22
23impl HotConfigChangeListener {
24    pub fn new() -> (Self, mpsc::Receiver<ConfigResponse>) {
25        let (sender, receiver) = mpsc::channel(1);
26        (Self { sender }, receiver)
27    }
28}
29
30pub struct ConfigManager<T> {
31    pub config: Arc<T>,
32    hot_config_sender: tokio::sync::mpsc::UnboundedSender<(
33        HotConfigSender<T>,
34        tokio::sync::oneshot::Sender<HotConfigReceiver<T>>,
35    )>,
36}
37
/// Request sent to the task that owns the hot configuration.
pub enum HotConfigSender<T> {
    /// Ask for the current configuration.
    GET,
    /// Replace the current configuration with the carried value.
    CHANGE(T),
}

/// Reply sent back by the task that owns the hot configuration.
pub enum HotConfigReceiver<T> {
    /// The current configuration, shared by reference count.
    GET(Arc<T>),
    /// Acknowledges that a `CHANGE` request has been applied.
    CHANGE,
}
47
48impl<T> ConfigManager<T>
49where
50    T: serde::de::DeserializeOwned + Send + Sync + 'static,
51{
52    pub fn build_hot_config(
53        config: T,
54        mut listener: tokio::sync::mpsc::Receiver<crate::config::ConfigResponse>,
55    ) -> Result<Self, crate::error::Error> {
56        let (sender, mut receive) = tokio::sync::mpsc::unbounded_channel::<(
57            HotConfigSender<T>,
58            tokio::sync::oneshot::Sender<HotConfigReceiver<T>>,
59        )>();
60        let config = Arc::new(config);
61        let mut cache = config.clone();
62        tokio::spawn(async move {
63            loop {
64                while let Some((msg, sender)) = receive.recv().await {
65                    match msg {
66                        HotConfigSender::GET => {
67                            let _ = sender.send(HotConfigReceiver::GET(cache.clone()));
68                        }
69                        HotConfigSender::CHANGE(ident) => {
70                            cache = std::sync::Arc::new(ident);
71                            let _ = sender.send(HotConfigReceiver::CHANGE);
72                        }
73                    }
74                }
75            }
76        });
77        let sender_clone = sender.clone();
78        tokio::spawn(async move {
79            let sender_clone = sender_clone;
80            while let Some(config_response) = listener.recv().await {
81                let Ok(ident) = config_build(config_response) else {
82                    error!("config_build error!");
83                    continue;
84                };
85                let (sender, receive) = tokio::sync::oneshot::channel();
86                let _ = sender_clone.send((HotConfigSender::CHANGE(ident), sender));
87                if receive.await.is_err() {
88                    error!("receive error!");
89                }
90            }
91        });
92        let config_manage = ConfigManager {
93            config,
94            hot_config_sender: sender,
95        };
96        Ok(config_manage)
97    }
98
99    pub async fn get_hot_config(&self) -> std::sync::Arc<T> {
100        let (sender, receive) = tokio::sync::oneshot::channel();
101        let _ = self
102            .hot_config_sender
103            .clone()
104            .send((HotConfigSender::GET, sender));
105        let Ok(HotConfigReceiver::GET(ident)) = receive.await else {
106            panic!("impossibility !!!");
107        };
108        ident
109    }
110}
111
112pub fn config_build<T: serde::de::DeserializeOwned>(
113    config_response: ConfigResponse,
114) -> Result<T, crate::error::Error> {
115    match config_response.content_type().as_str() {
116        "toml" => get_toml_by_context(config_response.content()),
117        "yaml" => get_yaml_by_context(config_response.content()),
118        r#type => Err(crate::error::Error::MessageError(r#type.to_string())),
119    }
120}
121
122pub fn get_toml_by_context<T: serde::de::DeserializeOwned>(
123    toml_context: &str,
124) -> Result<T, crate::error::Error> {
125    // 解析 TOML 文件内容
126    let parsed_toml: toml::Value = toml_context
127        .parse()
128        .map_err(|error| crate::error::Error::ConfigError(Box::new(error)))?;
129    let json = serde_json::json!(parsed_toml);
130    T::deserialize(json).map_err(|error| crate::error::Error::ConfigError(Box::new(error)))
131}
132
133pub fn get_yaml_by_context<T: serde::de::DeserializeOwned>(
134    yaml_context: &str,
135) -> Result<T, crate::error::Error> {
136    // 解析 yaml 文件内容
137    let parsed_toml: serde_yaml::Value = serde_yaml::from_str(yaml_context)
138        .map_err(|error| crate::error::Error::ConfigError(Box::new(error)))?;
139    T::deserialize(parsed_toml).map_err(|error| crate::error::Error::ConfigError(Box::new(error)))
140}
141
142pub fn get_config_by_path<T: serde::de::DeserializeOwned>(
143    path: &str,
144) -> Result<T, crate::error::Error> {
145    let contents = fs::read_to_string(path).unwrap_or_else(|_| panic!("read path erro : {path:?}"));
146    let file_type: Vec<&str> = path.split('.').collect();
147    match file_type[file_type.len() - 1].as_bytes() {
148        b"toml" => get_toml_by_context(&contents),
149        b"yaml" => get_yaml_by_context(&contents),
150        file_type => Err(crate::error::Error::MessageError(format!(
151            "not support {file_type:?}"
152        ))),
153    }
154}