Skip to main content

fbc_starter/
state.rs

1use chrono::Local;
2#[cfg(feature = "redis")]
3use deadpool_redis::Pool;
4#[cfg(feature = "mysql")]
5use sqlx::mysql::MySqlPool;
6#[cfg(feature = "postgres")]
7use sqlx::postgres::PgPool;
8#[cfg(feature = "sqlite")]
9use sqlx::sqlite::SqlitePool;
10use std::sync::Arc;
11
12use dashmap::DashMap;
13use std::any::{Any, TypeId};
14
15/// 应用状态(包含数据库连接池、Redis 客户端、Kafka 消息生产者/消费者、用户扩展等)
16#[derive(Clone)]
17pub struct AppState {
18    pub start_time: chrono::DateTime<Local>,
19    #[cfg(feature = "mysql")]
20    pub mysql: Option<Arc<MySqlPool>>,
21    #[cfg(feature = "postgres")]
22    pub postgres: Option<Arc<PgPool>>,
23    #[cfg(feature = "sqlite")]
24    pub sqlite: Option<Arc<SqlitePool>>,
25    #[cfg(feature = "redis")]
26    pub redis: Option<Pool>,
27    #[cfg(feature = "producer")]
28    pub message_producer: Option<crate::messaging::MessageProducerType>,
29    #[cfg(feature = "consumer")]
30    pub message_consumer: Option<crate::messaging::MessageConsumerType>,
31    /// 用户扩展状态容器(类型安全的 AnyMap)
32    extensions: Arc<DashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
33}
34
35impl AppState {
36    pub fn new() -> Self {
37        Self {
38            start_time: Local::now(),
39            #[cfg(feature = "mysql")]
40            mysql: None,
41            #[cfg(feature = "postgres")]
42            postgres: None,
43            #[cfg(feature = "sqlite")]
44            sqlite: None,
45            #[cfg(feature = "redis")]
46            redis: None,
47            #[cfg(feature = "producer")]
48            message_producer: None,
49            #[cfg(feature = "consumer")]
50            message_consumer: None,
51            extensions: Arc::new(DashMap::new()),
52        }
53    }
54
55    /// 设置用户扩展状态(类型安全)
56    ///
57    /// # 示例
58    /// ```ignore
59    /// struct MyConfig { pub api_key: String }
60    /// state.set_extension(MyConfig { api_key: "xxx".into() });
61    /// ```
62    pub fn set_extension<T: Any + Send + Sync + 'static>(&self, value: T) {
63        self.extensions.insert(TypeId::of::<T>(), Arc::new(value));
64    }
65
66    /// 获取用户扩展状态(类型安全)
67    ///
68    /// # 示例
69    /// ```ignore
70    /// let config = state.get_extension::<MyConfig>();
71    /// ```
72    pub fn get_extension<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
73        self.extensions
74            .get(&TypeId::of::<T>())
75            .and_then(|entry| entry.value().clone().downcast::<T>().ok())
76    }
77
78    /// 检查是否存在指定类型的扩展
79    pub fn has_extension<T: Any + Send + Sync + 'static>(&self) -> bool {
80        self.extensions.contains_key(&TypeId::of::<T>())
81    }
82
83    /// 带扩展的链式构建
84    pub fn with_extension<T: Any + Send + Sync + 'static>(self, value: T) -> Self {
85        self.set_extension(value);
86        self
87    }
88
89    /// 设置 MySQL 连接池
90    #[cfg(feature = "mysql")]
91    pub fn with_mysql(mut self, pool: Arc<MySqlPool>) -> Self {
92        self.mysql = Some(pool);
93        self
94    }
95
96    /// 设置 Postgres 连接池
97    #[cfg(feature = "postgres")]
98    pub fn with_postgres(mut self, pool: Arc<PgPool>) -> Self {
99        self.postgres = Some(pool);
100        self
101    }
102
103    /// 设置 SQLite 连接池
104    #[cfg(feature = "sqlite")]
105    pub fn with_sqlite(mut self, pool: Arc<SqlitePool>) -> Self {
106        self.sqlite = Some(pool);
107        self
108    }
109
110    /// 批量设置三个数据库连接池
111    #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
112    pub fn with_database_pools(mut self, pools: crate::database::DatabasePools) -> Self {
113        #[cfg(feature = "mysql")]
114        if let Some(pool) = pools.mysql {
115            self.mysql = Some(Arc::new(pool));
116        }
117        #[cfg(feature = "postgres")]
118        if let Some(pool) = pools.postgres {
119            self.postgres = Some(Arc::new(pool));
120        }
121        #[cfg(feature = "sqlite")]
122        if let Some(pool) = pools.sqlite {
123            self.sqlite = Some(Arc::new(pool));
124        }
125        self
126    }
127
128    /// 设置 Redis 连接池
129    #[cfg(feature = "redis")]
130    pub fn with_redis(mut self, pool: Pool) -> Self {
131        self.redis = Some(pool);
132        self
133    }
134
135    /// 设置 Kafka 消息生产者
136    #[cfg(feature = "producer")]
137    pub fn with_message_producer(
138        mut self,
139        message_producer: crate::messaging::MessageProducerType,
140    ) -> Self {
141        self.message_producer = Some(message_producer);
142        self
143    }
144
145    /// 设置 Kafka 消息消费者
146    #[cfg(feature = "consumer")]
147    pub fn with_message_consumer(
148        mut self,
149        message_consumer: crate::messaging::MessageConsumerType,
150    ) -> Self {
151        self.message_consumer = Some(message_consumer);
152        self
153    }
154
155    /// 获取 MySQL 连接池
156    #[cfg(feature = "mysql")]
157    pub fn mysql(&self) -> crate::error::AppResult<Arc<MySqlPool>> {
158        self.mysql
159            .clone()
160            .ok_or(crate::error::AppError::DatabaseNotInitialized)
161    }
162
163    /// 获取 Postgres 连接池
164    #[cfg(feature = "postgres")]
165    pub fn postgres(&self) -> crate::error::AppResult<Arc<PgPool>> {
166        self.postgres
167            .clone()
168            .ok_or(crate::error::AppError::DatabaseNotInitialized)
169    }
170
171    /// 获取 SQLite 连接池
172    #[cfg(feature = "sqlite")]
173    pub fn sqlite(&self) -> crate::error::AppResult<Arc<SqlitePool>> {
174        self.sqlite
175            .clone()
176            .ok_or(crate::error::AppError::DatabaseNotInitialized)
177    }
178
179    /// 获取 Redis 连接
180    ///
181    /// 每次调用都会从连接池获取一个新的连接,支持并发操作
182    #[cfg(feature = "redis")]
183    pub async fn redis(&self) -> crate::error::AppResult<deadpool_redis::Connection> {
184        self.redis
185            .as_ref()
186            .ok_or(crate::error::AppError::RedisNotInitialized)?
187            .get()
188            .await
189            .map_err(|e| {
190                crate::error::AppError::Internal(anyhow::anyhow!("获取 Redis 连接失败: {}", e))
191            })
192    }
193
194    /// 获取 Kafka 消息生产者
195    #[cfg(feature = "producer")]
196    pub fn message_producer(
197        &self,
198    ) -> crate::error::AppResult<&crate::messaging::MessageProducerType> {
199        self.message_producer.as_ref().ok_or_else(|| {
200            crate::error::AppError::Internal(anyhow::anyhow!("Kafka 消息生产者未初始化"))
201        })
202    }
203
204    /// 获取 Kafka 消息消费者
205    #[cfg(feature = "consumer")]
206    pub fn message_consumer(
207        &self,
208    ) -> crate::error::AppResult<&crate::messaging::MessageConsumerType> {
209        self.message_consumer.as_ref().ok_or_else(|| {
210            crate::error::AppError::Internal(anyhow::anyhow!("Kafka 消息消费者未初始化"))
211        })
212    }
213}
214
215impl Default for AppState {
216    fn default() -> Self {
217        Self::new()
218    }
219}