Skip to main content

zsh/
zprof.rs

//! Shell function profiling module - port of Modules/zprof.c
//!
//! Provides the `zprof` builtin for profiling shell functions.
4
5use std::collections::HashMap;
6use std::time::Instant;
7
/// Accumulated statistics for a single profiled shell function.
///
/// Times are in milliseconds, as accumulated by [`Profiler`].
#[derive(Clone, Debug)]
pub struct ProfFunc {
    /// Function name as passed to the profiler.
    pub name: String,
    /// Number of times the function was entered.
    pub calls: u64,
    /// Accumulated total time.
    pub total_time: f64,
    /// Accumulated self time.
    pub self_time: f64,
    /// Rank assigned while generating a report; 0 until then.
    pub num: usize,
}

impl ProfFunc {
    /// Creates a record for `name` with every counter at zero.
    pub fn new(name: &str) -> Self {
        let name = String::from(name);
        Self {
            name,
            calls: 0,
            total_time: 0.0,
            self_time: 0.0,
            num: 0,
        }
    }

    /// Mean total time per call, or `0.0` when the function was never called.
    pub fn avg_time(&self) -> f64 {
        Self::per_call(self.total_time, self.calls)
    }

    /// Mean self time per call, or `0.0` when the function was never called.
    pub fn avg_self(&self) -> f64 {
        Self::per_call(self.self_time, self.calls)
    }

    /// Divides `amount` evenly over `calls`, treating zero calls as zero.
    fn per_call(amount: f64, calls: u64) -> f64 {
        match calls {
            0 => 0.0,
            n => amount / n as f64,
        }
    }
}
45
/// Statistics for calls made from one function (`from`) into another (`to`).
#[derive(Clone, Debug)]
pub struct ProfArc {
    /// Calling function.
    pub from: String,
    /// Called function.
    pub to: String,
    /// Number of recorded calls along this arc.
    pub calls: u64,
    /// Accumulated total time along this arc (milliseconds).
    pub total_time: f64,
    /// Accumulated self time along this arc (milliseconds).
    pub self_time: f64,
}

impl ProfArc {
    /// Creates an arc `from -> to` with every counter at zero.
    pub fn new(from: &str, to: &str) -> Self {
        let (from, to) = (from.to_owned(), to.to_owned());
        Self {
            from,
            to,
            calls: 0,
            total_time: 0.0,
            self_time: 0.0,
        }
    }
}
67
68/// Stack frame for tracking function calls
69#[derive(Debug)]
70struct StackFrame {
71    func_name: String,
72    start_time: Instant,
73}
74
75/// Profiler state
76#[derive(Debug, Default)]
77pub struct Profiler {
78    functions: HashMap<String, ProfFunc>,
79    arcs: HashMap<(String, String), ProfArc>,
80    stack: Vec<StackFrame>,
81    enabled: bool,
82}
83
84impl Profiler {
85    pub fn new() -> Self {
86        Self {
87            functions: HashMap::new(),
88            arcs: HashMap::new(),
89            stack: Vec::new(),
90            enabled: true,
91        }
92    }
93
94    /// Start profiling a function call
95    pub fn enter_function(&mut self, name: &str) {
96        if !self.enabled {
97            return;
98        }
99
100        let func = self
101            .functions
102            .entry(name.to_string())
103            .or_insert_with(|| ProfFunc::new(name));
104        func.calls += 1;
105
106        if let Some(caller) = self.stack.last() {
107            let key = (caller.func_name.clone(), name.to_string());
108            let arc = self
109                .arcs
110                .entry(key)
111                .or_insert_with(|| ProfArc::new(&caller.func_name, name));
112            arc.calls += 1;
113        }
114
115        self.stack.push(StackFrame {
116            func_name: name.to_string(),
117            start_time: Instant::now(),
118        });
119    }
120
121    /// End profiling a function call
122    pub fn exit_function(&mut self, name: &str) {
123        if !self.enabled {
124            return;
125        }
126
127        if let Some(frame) = self.stack.pop() {
128            if frame.func_name != name {
129                self.stack.push(frame);
130                return;
131            }
132
133            let elapsed = frame.start_time.elapsed().as_secs_f64() * 1000.0;
134
135            if let Some(func) = self.functions.get_mut(name) {
136                func.self_time += elapsed;
137
138                let is_recursive = self.stack.iter().any(|f| f.func_name == name);
139                if !is_recursive {
140                    func.total_time += elapsed;
141                }
142            }
143
144            if let Some(caller) = self.stack.last() {
145                let key = (caller.func_name.clone(), name.to_string());
146                if let Some(arc) = self.arcs.get_mut(&key) {
147                    arc.self_time += elapsed;
148                    arc.total_time += elapsed;
149                }
150            }
151        }
152    }
153
154    /// Clear all profiling data
155    pub fn clear(&mut self) {
156        self.functions.clear();
157        self.arcs.clear();
158        self.stack.clear();
159    }
160
161    /// Enable profiling
162    pub fn enable(&mut self) {
163        self.enabled = true;
164    }
165
166    /// Disable profiling
167    pub fn disable(&mut self) {
168        self.enabled = false;
169    }
170
171    /// Check if profiling is enabled
172    pub fn is_enabled(&self) -> bool {
173        self.enabled
174    }
175
176    /// Get total time across all functions
177    pub fn total_time(&self) -> f64 {
178        self.functions.values().map(|f| f.self_time).sum()
179    }
180
181    /// Get functions sorted by self time (descending)
182    pub fn functions_by_self(&self) -> Vec<&ProfFunc> {
183        let mut funcs: Vec<_> = self.functions.values().collect();
184        funcs.sort_by(|a, b| b.self_time.partial_cmp(&a.self_time).unwrap());
185        funcs
186    }
187
188    /// Get functions sorted by total time (descending)
189    pub fn functions_by_total(&self) -> Vec<&ProfFunc> {
190        let mut funcs: Vec<_> = self.functions.values().collect();
191        funcs.sort_by(|a, b| b.total_time.partial_cmp(&a.total_time).unwrap());
192        funcs
193    }
194
195    /// Get arcs sorted by time (descending)
196    pub fn arcs_by_time(&self) -> Vec<&ProfArc> {
197        let mut arcs: Vec<_> = self.arcs.values().collect();
198        arcs.sort_by(|a, b| b.total_time.partial_cmp(&a.total_time).unwrap());
199        arcs
200    }
201
202    /// Generate profile report
203    pub fn report(&mut self) -> String {
204        let mut output = String::new();
205        let total = self.total_time();
206
207        if total == 0.0 {
208            return "No profiling data collected.\n".to_string();
209        }
210
211        output.push_str(
212            "num  calls                time                       self            name\n",
213        );
214        output.push_str(
215            "-----------------------------------------------------------------------------------\n",
216        );
217
218        let mut funcs_by_self: Vec<_> = self.functions.values_mut().collect();
219        funcs_by_self.sort_by(|a, b| b.self_time.partial_cmp(&a.self_time).unwrap());
220
221        for (i, func) in funcs_by_self.iter_mut().enumerate() {
222            func.num = i + 1;
223            let time_pct = (func.total_time / total) * 100.0;
224            let self_pct = (func.self_time / total) * 100.0;
225
226            output.push_str(&format!(
227                "{:2}) {:4}       {:8.2} {:8.2}  {:6.2}%  {:8.2} {:8.2}  {:6.2}%  {}\n",
228                func.num,
229                func.calls,
230                func.total_time,
231                func.avg_time(),
232                time_pct,
233                func.self_time,
234                func.avg_self(),
235                self_pct,
236                func.name
237            ));
238        }
239
240        let func_nums: HashMap<String, usize> = self
241            .functions
242            .iter()
243            .map(|(name, f)| (name.clone(), f.num))
244            .collect();
245
246        let mut funcs_by_total: Vec<_> = self.functions.values().collect();
247        funcs_by_total.sort_by(|a, b| b.total_time.partial_cmp(&a.total_time).unwrap());
248
249        for func in funcs_by_total {
250            output.push_str("\n-----------------------------------------------------------------------------------\n\n");
251
252            let arcs: Vec<_> = self.arcs.values().filter(|a| a.to == func.name).collect();
253
254            for arc in &arcs {
255                let from_num = func_nums.get(&arc.from).copied().unwrap_or(0);
256                let time_pct = (arc.total_time / total) * 100.0;
257                output.push_str(&format!(
258                    "    {:4}/{:<4}  {:8.2} {:8.2}  {:6.2}%  {:8.2} {:8.2}             {} [{}]\n",
259                    arc.calls,
260                    func.calls,
261                    arc.total_time,
262                    if arc.calls > 0 {
263                        arc.total_time / arc.calls as f64
264                    } else {
265                        0.0
266                    },
267                    time_pct,
268                    arc.self_time,
269                    if arc.calls > 0 {
270                        arc.self_time / arc.calls as f64
271                    } else {
272                        0.0
273                    },
274                    arc.from,
275                    from_num
276                ));
277            }
278
279            let time_pct = (func.total_time / total) * 100.0;
280            let self_pct = (func.self_time / total) * 100.0;
281            output.push_str(&format!(
282                "{:2}) {:4}       {:8.2} {:8.2}  {:6.2}%  {:8.2} {:8.2}  {:6.2}%  {}\n",
283                func.num,
284                func.calls,
285                func.total_time,
286                func.avg_time(),
287                time_pct,
288                func.self_time,
289                func.avg_self(),
290                self_pct,
291                func.name
292            ));
293
294            let callee_arcs: Vec<_> = self.arcs.values().filter(|a| a.from == func.name).collect();
295
296            for arc in callee_arcs.iter().rev() {
297                let to_num = func_nums.get(&arc.to).copied().unwrap_or(0);
298                let to_calls = self.functions.get(&arc.to).map(|f| f.calls).unwrap_or(0);
299                let time_pct = (arc.total_time / total) * 100.0;
300                output.push_str(&format!(
301                    "    {:4}/{:<4}  {:8.2} {:8.2}  {:6.2}%  {:8.2} {:8.2}             {} [{}]\n",
302                    arc.calls,
303                    to_calls,
304                    arc.total_time,
305                    if arc.calls > 0 {
306                        arc.total_time / arc.calls as f64
307                    } else {
308                        0.0
309                    },
310                    time_pct,
311                    arc.self_time,
312                    if arc.calls > 0 {
313                        arc.self_time / arc.calls as f64
314                    } else {
315                        0.0
316                    },
317                    arc.to,
318                    to_num
319                ));
320            }
321        }
322
323        output
324    }
325}
326
/// Parsed options for the `zprof` builtin.
#[derive(Default, Debug)]
pub struct ZprofOptions {
    /// Discard collected data instead of producing a report
    /// (presumably zsh's `zprof -c`; confirm against the option parser).
    pub clear: bool,
}
332
333/// Execute zprof builtin
334pub fn builtin_zprof(profiler: &mut Profiler, options: &ZprofOptions) -> (i32, String) {
335    if options.clear {
336        profiler.clear();
337        (0, String::new())
338    } else {
339        (0, profiler.report())
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Builds an enabled profiler that has recorded one complete call of `name`.
    fn profiled_once(name: &str) -> Profiler {
        let mut profiler = Profiler::new();
        profiler.enter_function(name);
        profiler.exit_function(name);
        profiler
    }

    #[test]
    fn test_prof_func_new() {
        let record = ProfFunc::new("test_func");
        assert_eq!(record.name, "test_func");
        assert_eq!(record.calls, 0);
        assert_eq!(record.total_time, 0.0);
        assert_eq!(record.self_time, 0.0);
    }

    #[test]
    fn test_prof_func_avg() {
        let record = ProfFunc {
            calls: 4,
            total_time: 100.0,
            self_time: 80.0,
            ..ProfFunc::new("test")
        };
        assert_eq!(record.avg_time(), 25.0);
        assert_eq!(record.avg_self(), 20.0);
    }

    #[test]
    fn test_prof_arc_new() {
        let edge = ProfArc::new("caller", "callee");
        assert_eq!(edge.from, "caller");
        assert_eq!(edge.to, "callee");
        assert_eq!(edge.calls, 0);
    }

    #[test]
    fn test_profiler_new() {
        let profiler = Profiler::new();
        assert!(profiler.is_enabled());
        assert!(profiler.functions.is_empty());
        assert!(profiler.arcs.is_empty());
    }

    #[test]
    fn test_profiler_enter_exit() {
        let mut profiler = Profiler::new();
        profiler.enter_function("func1");
        thread::sleep(Duration::from_millis(10));
        profiler.exit_function("func1");

        assert_eq!(profiler.functions.len(), 1);
        let record = &profiler.functions["func1"];
        assert_eq!(record.calls, 1);
        assert!(record.self_time > 0.0);
    }

    #[test]
    fn test_profiler_nested_calls() {
        let mut profiler = Profiler::new();
        profiler.enter_function("outer");
        profiler.enter_function("inner");
        thread::sleep(Duration::from_millis(5));
        profiler.exit_function("inner");
        profiler.exit_function("outer");

        assert_eq!(profiler.functions.len(), 2);
        assert_eq!(profiler.arcs.len(), 1);

        let key = ("outer".to_string(), "inner".to_string());
        assert_eq!(profiler.arcs[&key].calls, 1);
    }

    #[test]
    fn test_profiler_clear() {
        let mut profiler = profiled_once("test");
        assert!(!profiler.functions.is_empty());

        profiler.clear();
        assert!(profiler.functions.is_empty());
        assert!(profiler.arcs.is_empty());
    }

    #[test]
    fn test_profiler_disable() {
        let mut profiler = Profiler::new();
        profiler.disable();
        profiler.enter_function("test");
        profiler.exit_function("test");

        assert!(profiler.functions.is_empty());
    }

    #[test]
    fn test_builtin_zprof_clear() {
        let mut profiler = profiled_once("test");
        let (status, _) = builtin_zprof(&mut profiler, &ZprofOptions { clear: true });

        assert_eq!(status, 0);
        assert!(profiler.functions.is_empty());
    }

    #[test]
    fn test_builtin_zprof_report() {
        let mut profiler = Profiler::new();
        let (status, output) = builtin_zprof(&mut profiler, &ZprofOptions::default());

        assert_eq!(status, 0);
        assert!(output.contains("No profiling data"));
    }
}