// br_db/pools.rs

1use crate::config::Connection;
2
/// Placeholder database handle stored in `DbOption::Some`.
///
/// Rust has no forward declarations: this line defines a new, zero-sized unit
/// struct local to this module, distinct from any other `Db` type in the crate.
/// It was introduced to avoid a circular import.
// NOTE(review): confirm a placeholder type is really intended here rather than
// the crate's real database type.
pub struct Db;
5use json::JsonValue;
6use log::warn;
7use std::sync::{mpsc, Arc, Mutex};
8use std::thread;
9use std::thread::JoinHandle;
10
/// A unit of work for the pool. It is called once with the index of the
/// worker that runs it and produces one JSON result.
type Job = Box<dyn FnOnce(usize) -> JsonValue + 'static + Send>;

/// Messages sent from [`Pool`] to its workers over the shared channel.
enum Message {
    /// Tells the receiving worker to stop and hand back its collected results.
    End,
    /// A job for the receiving worker to run.
    NewJob(Job),
}
17
18struct Worker {
19    _id: usize,
20    t: Option<JoinHandle<Vec<JsonValue>>>,
21}
22
23impl Worker {
24    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Worker {
25        let t = thread::spawn(move || {
26            let mut list = vec![];
27            // 计算loop执行次数
28            let mut count_flag: i64 = 0;
29            loop {
30                // 改进错误处理,避免 unwrap
31                let message = match receiver.lock() {
32                    Ok(guard) => match guard.recv() {
33                        Ok(msg) => msg,
34                        Err(e) => {
35                            warn!("Worker {} 接收消息失败: {}", id, e);
36                            break;
37                        }
38                    },
39                    Err(e) => {
40                        warn!("Worker {} 获取锁失败: {}", id, e);
41                        break;
42                    }
43                };
44
45                match message {
46                    Message::NewJob(job) => {
47                        list.push(job(id));
48                        count_flag += 1;
49                        if count_flag == 10000 {
50                            warn!("Worker {} 循环次数: 1w,强制退出", id);
51                            break;
52                        }
53                    }
54                    Message::End => {
55                        break;
56                    }
57                }
58            }
59            list
60        });
61        Worker {
62            _id: id,
63            t: Some(t),
64        }
65    }
66}
67
68pub struct Pool {
69    workers: Vec<Worker>,
70    max_workers: usize,
71    sender: mpsc::Sender<Message>,
72}
73
74impl Pool {
75    pub fn new(max_workers: usize) -> Pool {
76        if max_workers == 0 {
77            println!("max_workers 必须大于0")
78        }
79        let (tx, rx) = mpsc::channel();
80        let mut workers = Vec::with_capacity(max_workers);
81        let receiver = Arc::new(Mutex::new(rx));
82        for i in 0..max_workers {
83            workers.push(Worker::new(i, Arc::clone(&receiver)));
84        }
85        Pool {
86            workers,
87            max_workers,
88            sender: tx,
89        }
90    }
91    pub fn execute<F>(&self, f: F)
92    where
93        F: 'static + Send + FnOnce(usize) -> JsonValue,
94    {
95        let job = Message::NewJob(Box::new(f));
96        if let Err(e) = self.sender.send(job) {
97            warn!("发送任务失败: {}", e);
98        }
99    }
100    pub fn end(&mut self) -> JsonValue {
101        // 改进错误处理
102        for _ in 0..self.max_workers {
103            if let Err(e) = self.sender.send(Message::End) {
104                warn!("发送结束消息失败: {}", e);
105            }
106        }
107        let mut list = Vec::new();
108        for w in self.workers.iter_mut() {
109            if let Some(t) = w.t.take() {
110                match t.join() {
111                    Ok(data) => {
112                        list.extend(data);
113                    }
114                    Err(e) => {
115                        warn!("线程连接失败: {:?}", e);
116                    }
117                }
118            }
119        }
120        JsonValue::from(list)
121    }
122    pub fn insert_all(&mut self) -> (Vec<String>, String) {
123        for _ in 0..self.max_workers {
124            if let Err(e) = self.sender.send(Message::End) {
125                warn!("发送结束消息失败: {}", e);
126            }
127        }
128        // 预分配容量,减少重新分配
129        let mut id = Vec::new();
130        let mut list_parts = Vec::new();
131
132        for w in self.workers.iter_mut() {
133            if let Some(t) = w.t.take() {
134                match t.join() {
135                    Ok(data) => {
136                        for item in data.iter() {
137                            id.push(item[0].to_string());
138                            list_parts.push(item[1].to_string());
139                        }
140                    }
141                    Err(e) => {
142                        warn!("线程连接失败: {:?}", e);
143                    }
144                }
145            }
146        }
147        // 使用 join 而不是 format!,更高效
148        let list = list_parts.join(",");
149        (id, list)
150    }
151}
152
/// Public-module client.
pub struct PubModClient {
    /// Database identifier, presumably the database name (TODO confirm);
    /// empty after `new`.
    pub database: String,
    /// Arbitrary JSON payload carried by the client; `Null` after `new`.
    pub data: JsonValue,
    /// Connection settings from `crate::config`.
    pub conn: Connection,
    /// Optional placeholder `Db` handle; `DbOption::None` after `new`.
    pub db: DbOption,
}
160
/// Optional `Db` slot used by `PubModClient::db`.
///
/// Same shape as `Option<Db>`.
// NOTE(review): consider replacing with `Option<Db>` in a future breaking
// release.
pub enum DbOption {
    /// A `Db` value is present.
    Some(Db),
    /// No `Db` value.
    None,
}
165
166impl PubModClient {
167    pub fn new() -> Self {
168        Self {
169            database: String::new(),
170            data: JsonValue::Null,
171            conn: Connection::new(""),
172            db: DbOption::None,
173        }
174    }
175}
176
177impl Default for PubModClient {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
/// Backward-compatible alias that keeps the legacy name `Pub_Mod_Client` working.
// `non_camel_case_types` is allowed because the legacy spelling must be kept
// verbatim for existing callers.
#[allow(non_camel_case_types)]
pub type Pub_Mod_Client = PubModClient;