1use crate::core::{MemScopeError, MemScopeResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ThreadInfo {
14 pub thread_id: u64,
16 pub thread_name: Option<String>,
18 pub created_at: u64,
20 pub allocation_count: usize,
22 pub total_allocated: usize,
24 pub peak_memory: usize,
26 pub is_active: bool,
28}
29
30#[derive(Debug)]
32pub struct ThreadRegistry {
33 threads: Arc<Mutex<HashMap<u64, ThreadInfo>>>,
35 next_id: Arc<Mutex<u64>>,
37}
38
39impl ThreadRegistry {
40 pub fn new() -> Self {
42 Self {
43 threads: Arc::new(Mutex::new(HashMap::new())),
44 next_id: Arc::new(Mutex::new(1)),
45 }
46 }
47
48 pub fn next_id(&self) -> MemScopeResult<u64> {
49 let mut id = self.next_id.lock().map_err(|e| {
50 MemScopeError::system(
51 crate::core::error::SystemErrorType::Locking,
52 format!("Failed to acquire next_id lock: {}", e),
53 )
54 })?;
55 let current = *id;
56 *id = current.saturating_add(1);
57 Ok(current)
58 }
59
60 pub fn register_current_thread(&self) -> MemScopeResult<u64> {
62 let thread_id = std::thread::current().id();
63 let hash = self.hash_thread_id(&thread_id);
64 let timestamp = std::time::SystemTime::now()
65 .duration_since(std::time::UNIX_EPOCH)
66 .unwrap_or_default()
67 .as_nanos() as u64;
68
69 let mut threads = self.threads.lock().map_err(|e| {
70 MemScopeError::system(
71 crate::core::error::SystemErrorType::Locking,
72 format!("Failed to acquire threads lock: {}", e),
73 )
74 })?;
75 threads.entry(hash).or_insert_with(|| ThreadInfo {
76 thread_id: hash,
77 thread_name: Some(format!("{:?}", thread_id)),
78 created_at: timestamp,
79 allocation_count: 0,
80 total_allocated: 0,
81 peak_memory: 0,
82 is_active: true,
83 });
84 Ok(hash)
85 }
86
87 pub fn get_thread_info(&self, hash: u64) -> MemScopeResult<Option<ThreadInfo>> {
89 let threads = self.threads.lock().map_err(|e| {
90 MemScopeError::system(
91 crate::core::error::SystemErrorType::Locking,
92 format!("Failed to acquire threads lock: {}", e),
93 )
94 })?;
95 Ok(threads.get(&hash).cloned())
96 }
97
98 pub fn record_allocation(&self, hash: u64, size: usize) -> MemScopeResult<()> {
100 let mut threads = self.threads.lock().map_err(|e| {
101 MemScopeError::system(
102 crate::core::error::SystemErrorType::Locking,
103 format!("Failed to acquire threads lock: {}", e),
104 )
105 })?;
106 if let Some(info) = threads.get_mut(&hash) {
107 info.allocation_count += 1;
108 info.total_allocated += size;
109 }
110 Ok(())
111 }
112
113 pub fn update_peak_memory(&self, hash: u64, current_memory: usize) -> MemScopeResult<()> {
115 let mut threads = self.threads.lock().map_err(|e| {
116 MemScopeError::system(
117 crate::core::error::SystemErrorType::Locking,
118 format!("Failed to acquire threads lock: {}", e),
119 )
120 })?;
121 if let Some(info) = threads.get_mut(&hash) {
122 if current_memory > info.peak_memory {
123 info.peak_memory = current_memory;
124 }
125 }
126 Ok(())
127 }
128
129 pub fn mark_thread_inactive(&self, hash: u64) -> MemScopeResult<()> {
131 let mut threads = self.threads.lock().map_err(|e| {
132 MemScopeError::system(
133 crate::core::error::SystemErrorType::Locking,
134 format!("Failed to acquire threads lock: {}", e),
135 )
136 })?;
137 if let Some(info) = threads.get_mut(&hash) {
138 info.is_active = false;
139 }
140 Ok(())
141 }
142
143 pub fn get_all_threads(&self) -> MemScopeResult<Vec<ThreadInfo>> {
145 let threads = self.threads.lock().map_err(|e| {
146 MemScopeError::system(
147 crate::core::error::SystemErrorType::Locking,
148 format!("Failed to acquire threads lock: {}", e),
149 )
150 })?;
151 Ok(threads.values().cloned().collect())
152 }
153
154 pub fn get_active_threads(&self) -> MemScopeResult<Vec<ThreadInfo>> {
156 let threads = self.threads.lock().map_err(|e| {
157 MemScopeError::system(
158 crate::core::error::SystemErrorType::Locking,
159 format!("Failed to acquire threads lock: {}", e),
160 )
161 })?;
162 Ok(threads.values().filter(|t| t.is_active).cloned().collect())
163 }
164
165 pub fn len(&self) -> MemScopeResult<usize> {
167 let threads = self.threads.lock().map_err(|e| {
168 MemScopeError::system(
169 crate::core::error::SystemErrorType::Locking,
170 format!("Failed to acquire threads lock: {}", e),
171 )
172 })?;
173 Ok(threads.len())
174 }
175
176 pub fn is_empty(&self) -> MemScopeResult<bool> {
178 Ok(self.len()? == 0)
179 }
180
181 fn hash_thread_id(&self, thread_id: &std::thread::ThreadId) -> u64 {
183 crate::utils::thread_id_to_u64(*thread_id)
184 }
185}
186
187impl Default for ThreadRegistry {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed registry holds no threads.
    #[test]
    fn test_thread_registry_creation() {
        let registry = ThreadRegistry::new();
        assert!(registry.is_empty().unwrap(), "New registry should be empty");
    }

    /// Registering the current thread creates one active entry.
    #[test]
    fn test_register_current_thread() {
        let registry = ThreadRegistry::new();
        let hash = registry.register_current_thread().unwrap();
        assert!(hash > 0, "Thread hash should be positive");
        assert_eq!(
            registry.len().unwrap(),
            1,
            "Registry should have one thread"
        );

        let info = registry.get_thread_info(hash).unwrap();
        assert!(info.is_some(), "Thread info should exist");
        assert!(
            info.unwrap().is_active,
            "Thread should be active by default"
        );
    }

    /// Threads registering concurrently each get their own distinct entry.
    #[test]
    fn test_register_from_multiple_threads() {
        let registry = std::sync::Arc::new(ThreadRegistry::new());
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let registry = std::sync::Arc::clone(&registry);
                std::thread::spawn(move || registry.register_current_thread().unwrap())
            })
            .collect();
        let mut hashes: Vec<u64> = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect();

        hashes.sort_unstable();
        hashes.dedup();
        assert_eq!(hashes.len(), 4, "Each thread should get a distinct hash");
        assert_eq!(
            registry.len().unwrap(),
            4,
            "Registry should hold one entry per spawned thread"
        );
    }

    /// Allocations accumulate both the count and the total size.
    #[test]
    fn test_record_allocation() {
        let registry = ThreadRegistry::new();
        let hash = registry.register_current_thread().unwrap();

        registry.record_allocation(hash, 100).unwrap();
        registry.record_allocation(hash, 200).unwrap();

        let info = registry.get_thread_info(hash).unwrap().unwrap();
        assert_eq!(info.allocation_count, 2, "Should have 2 allocations");
        assert_eq!(info.total_allocated, 300, "Total allocated should be 300");
    }

    /// The peak only moves upward.
    #[test]
    fn test_update_peak_memory() {
        let registry = ThreadRegistry::new();
        let hash = registry.register_current_thread().unwrap();

        registry.update_peak_memory(hash, 100).unwrap();
        registry.update_peak_memory(hash, 200).unwrap();
        registry.update_peak_memory(hash, 150).unwrap();

        let info = registry.get_thread_info(hash).unwrap().unwrap();
        assert_eq!(info.peak_memory, 200, "Peak memory should be 200");
    }

    /// Marking a thread inactive clears its `is_active` flag.
    #[test]
    fn test_mark_thread_inactive() {
        let registry = ThreadRegistry::new();
        let hash = registry.register_current_thread().unwrap();

        registry.mark_thread_inactive(hash).unwrap();

        let info = registry.get_thread_info(hash).unwrap().unwrap();
        assert!(!info.is_active, "Thread should be marked inactive");
    }

    /// Inactive threads are excluded from `get_active_threads`.
    #[test]
    fn test_get_active_threads() {
        let registry = ThreadRegistry::new();
        let hash = registry.register_current_thread().unwrap();

        let active = registry.get_active_threads().unwrap();
        assert_eq!(active.len(), 1, "Should have one active thread");

        registry.mark_thread_inactive(hash).unwrap();

        let active = registry.get_active_threads().unwrap();
        assert!(
            active.is_empty(),
            "Should have no active threads after marking inactive"
        );
    }

    /// `get_all_threads` returns every registered thread.
    #[test]
    fn test_get_all_threads() {
        let registry = ThreadRegistry::new();
        let hash = registry.register_current_thread().unwrap();

        let all = registry.get_all_threads().unwrap();
        assert_eq!(all.len(), 1, "Should have one thread total");
        assert_eq!(all[0].thread_id, hash, "Entry should carry the registered hash");
    }

    /// IDs are handed out sequentially starting at 1.
    #[test]
    fn test_next_id() {
        let registry = ThreadRegistry::new();
        let id1 = registry.next_id().unwrap();
        let id2 = registry.next_id().unwrap();
        let id3 = registry.next_id().unwrap();

        assert_eq!(id1, 1, "First ID should be 1");
        assert_eq!(id2, 2, "Second ID should be 2");
        assert_eq!(id3, 3, "Third ID should be 3");
    }

    /// Recording for an unknown key succeeds and does not create an entry.
    #[test]
    fn test_record_allocation_unknown_thread() {
        let registry = ThreadRegistry::new();
        let result = registry.record_allocation(99999, 100);
        assert!(result.is_ok(), "Should not error on unknown thread");
        assert!(
            registry.is_empty().unwrap(),
            "Unknown thread must not be implicitly registered"
        );
    }

    /// Updating the peak for an unknown key succeeds and does not create an entry.
    #[test]
    fn test_update_peak_memory_unknown_thread() {
        let registry = ThreadRegistry::new();
        let result = registry.update_peak_memory(99999, 100);
        assert!(result.is_ok(), "Should not error on unknown thread");
        assert!(
            registry.is_empty().unwrap(),
            "Unknown thread must not be implicitly registered"
        );
    }

    /// Deactivating an unknown key succeeds and does not create an entry.
    #[test]
    fn test_mark_thread_inactive_unknown_thread() {
        let registry = ThreadRegistry::new();
        let result = registry.mark_thread_inactive(99999);
        assert!(result.is_ok(), "Should not error on unknown thread");
        assert!(
            registry.is_empty().unwrap(),
            "Unknown thread must not be implicitly registered"
        );
    }

    /// Looking up an unknown key yields `None`.
    #[test]
    fn test_get_thread_info_unknown() {
        let registry = ThreadRegistry::new();
        let info = registry.get_thread_info(99999).unwrap();
        assert!(info.is_none(), "Unknown thread should return None");
    }

    /// `Default` produces an empty registry.
    #[test]
    fn test_default() {
        let registry = ThreadRegistry::default();
        assert!(
            registry.is_empty().unwrap(),
            "Default registry should be empty"
        );
    }

    /// Cloning copies the field values.
    #[test]
    fn test_thread_info_clone() {
        let info = ThreadInfo {
            thread_id: 1,
            thread_name: Some("test".to_string()),
            created_at: 12345,
            allocation_count: 10,
            total_allocated: 1000,
            peak_memory: 500,
            is_active: true,
        };

        let cloned = info.clone();
        assert_eq!(
            cloned.thread_id, info.thread_id,
            "Cloned thread_id should match"
        );
        assert_eq!(
            cloned.allocation_count, info.allocation_count,
            "Cloned allocation_count should match"
        );
    }

    /// The `Debug` output names the type and its fields.
    #[test]
    fn test_thread_info_debug() {
        let info = ThreadInfo {
            thread_id: 1,
            thread_name: Some("test".to_string()),
            created_at: 0,
            allocation_count: 0,
            total_allocated: 0,
            peak_memory: 0,
            is_active: true,
        };

        let debug_str = format!("{:?}", info);
        assert!(
            debug_str.contains("ThreadInfo"),
            "Debug output should contain ThreadInfo"
        );
        assert!(
            debug_str.contains("thread_id"),
            "Debug output should contain thread_id"
        );
    }

    /// Registering the same thread twice reuses its entry.
    #[test]
    fn test_len_multiple_registrations() {
        let registry = ThreadRegistry::new();
        let hash1 = registry.register_current_thread().unwrap();
        let hash2 = registry.register_current_thread().unwrap();

        assert_eq!(hash1, hash2, "Same thread should get same hash");
        assert_eq!(
            registry.len().unwrap(),
            1,
            "Same thread registered twice should still be 1"
        );
    }
}