Skip to main content

tract_libcli/
lib.rs

1#![allow(clippy::collapsible_if)]
2#[macro_use]
3extern crate log;
4
5pub mod annotations;
6pub mod display_params;
7pub mod draw;
8pub mod export;
9pub mod model;
10pub mod profile;
11pub mod tensor;
12pub mod terminal;
13pub mod time;
14
15use tract_core::internal::*;
16#[allow(unused_imports)]
17#[cfg(any(target_os = "linux", target_os = "windows"))]
18use tract_cuda::utils::ensure_cuda_runtime_dependencies;
19
20pub fn capture_gpu_trace<F>(matches: &clap::ArgMatches, func: F) -> TractResult<()>
21where
22    F: FnOnce() -> TractResult<()>,
23{
24    if matches.contains_id("metal-gpu-trace")
25        && matches.get_one::<String>("metal-gpu-trace").is_some()
26    {
27        #[cfg(any(target_os = "macos", target_os = "ios"))]
28        {
29            let gpu_trace_path =
30                std::path::Path::new(matches.get_one::<String>("metal-gpu-trace").unwrap())
31                    .to_path_buf();
32            ensure!(gpu_trace_path.is_absolute(), "Metal GPU trace file has to be absolute");
33            ensure!(
34                !gpu_trace_path.exists(),
35                format!("Given Metal GPU trace file {:?} already exists.", gpu_trace_path)
36            );
37
38            log::info!("Capturing Metal GPU trace at : {gpu_trace_path:?}");
39            tract_metal::METAL_STREAM.with_borrow(move |stream| {
40                stream.capture_trace(gpu_trace_path, move |_stream| func())
41            })
42        }
43        #[cfg(not(any(target_os = "macos", target_os = "ios")))]
44        {
45            bail!("`--metal-gpu-trace` present but it is only available on MacOS and iOS")
46        }
47    } else if matches.contains_id("cuda-gpu-trace")
48        && matches.get_one::<String>("cuda-gpu-trace").is_some()
49    {
50        #[cfg(any(target_os = "linux", target_os = "windows"))]
51        {
52            ensure_cuda_runtime_dependencies(
53                "`--cuda-gpu-trace` present but no CUDA installation has been found",
54            )?;
55            let _prof = cudarc::driver::safe::Profiler::new()?;
56            func()
57        }
58        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
59        {
60            bail!("`--cuda-gpu-trace` present but it is only available on Linux and Windows")
61        }
62    } else {
63        func()
64    }
65}