Struct DebugOptions

pub struct DebugOptions {
    /* 201 fields, documented in the Fields section below */
}

Debugging options for XLA. These options may change at any time - there are no guarantees about backward or forward compatibility for these fields.

Debug options naming and organization:

  1. Backend-agnostic options: xla_$flag_name - these go first, sorted alphabetically by flag name.

  2. Backend-specific options: xla_$backend_$flag_name - these must be in the corresponding backend section, sorted alphabetically by flag name.

Fields§

§xla_cpu_enable_concurrency_optimized_scheduler: bool

When true, XLA:CPU uses an HLO module scheduler that is optimized for extracting concurrency at the cost of extra memory: we extend the live ranges of temporaries to allow the XLA runtime to schedule independent operations in parallel on separate threads.

§xla_cpu_enable_fast_math: bool

When true, “unsafe” mathematical optimizations are enabled. These transformations include but are not limited to:

  • Reducing the precision of operations (e.g. using an approximate sin function, or transforming x/y into x * (1/y)).
  • Assuming that operations never produce or consume NaN or +/- Inf (this behavior can be adjusted using xla_cpu_fast_math_honor_{nans|infs}).
  • Assuming that +0 and -0 are indistinguishable.

§xla_cpu_enable_fast_min_max: bool

When false, we lower the Minimum and Maximum HLOs in the CPU backend such that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if this flag is false we always propagate NaNs through Min and Max.

Note that this does not correspond to exactly the same behavior as the GPU flag below!

§xla_cpu_fast_math_honor_division: bool

When xla_cpu_enable_fast_math is true, this controls whether we forbid using the reciprocal of an argument instead of division. Ignored when xla_cpu_enable_fast_math is false.

§xla_cpu_fast_math_honor_functions: bool

When xla_cpu_enable_fast_math is true, this controls whether we forbid approximating calculations for functions. Ignored when xla_cpu_enable_fast_math is false.

§xla_cpu_fast_math_honor_infs: bool

When xla_cpu_enable_fast_math is true, this controls whether we allow operations to produce infinities. Ignored when xla_cpu_enable_fast_math is false.

§xla_cpu_fast_math_honor_nans: bool

When xla_cpu_enable_fast_math is true, this controls whether we allow operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is false.

§xla_cpu_use_thunk_runtime: bool

When true, XLA:CPU uses the thunk runtime to execute the compiled program.

§xla_cpu_parallel_codegen_split_count: i32

The number of parts to split the LLVM module into before codegen. This allows XLA to compile all parts in parallel, and resolve kernel symbols from different dynamic libraries.

§xla_cpu_prefer_vector_width: i32

A prefer-vector-width value that is passed to the LLVM backend. Default value is 256 (AVX2 on x86 platforms).

§xla_gpu_experimental_autotune_cache_mode: i32

Specifies the behavior of the per-kernel autotuning cache.

§xla_gpu_experimental_enable_triton_softmax_priority_fusion: bool

Gates the experimental feature coupling the Triton Softmax pattern matcher with priority fusion.

§xla_gpu_unsupported_enable_triton_gemm: bool

Internal debug/testing flag to switch Triton GEMM fusions on or off.

§xla_hlo_graph_addresses: bool

Show addresses of HLO ops in graph dump.

§xla_hlo_profile: bool

Instrument the computation to collect per-HLO cycle counts.

§xla_disable_hlo_passes: Vec<String>

List of HLO passes to disable/enable. These names must exactly match the pass names as specified by the HloPassInterface::name() method.

At least one of xla_disable_hlo_passes and xla_enable_hlo_passes_only must be empty.

§xla_enable_hlo_passes_only: Vec<String>

§xla_disable_all_hlo_passes: bool

Disables all HLO passes. Note that some passes are necessary for correctness, and the invariants that must be satisfied by “fully optimized” HLO are different for different devices and may change over time. The only “guarantee”, such as it is, is that if you compile XLA and dump the optimized HLO for some graph, you should be able to run it again on the same device with the same build of XLA.
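
A minimal sketch of how these pass-control fields might be set when building a DebugOptions directly in Rust; the pass name "algsimp" is only illustrative, and any real name must match HloPassInterface::name() exactly:

// Hypothetical usage; "algsimp" is an illustrative pass name.
let mut opts = DebugOptions::default();
opts.xla_disable_hlo_passes = vec!["algsimp".to_string()];
// Per the constraint above, the complementary allow-list must stay empty:
assert!(opts.xla_enable_hlo_passes_only.is_empty());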

§xla_backend_optimization_level: i32

Numerical optimization level for the XLA compiler backend; the specific interpretation of this value is left to the backends.

§xla_embed_ir_in_executable: bool

Embed the compiler IR as a string in the executable.

§xla_eliminate_hlo_implicit_broadcast: bool

Eliminate implicit broadcasts when lowering user computations to HLO instructions; use explicit broadcast instead.

§xla_cpu_multi_thread_eigen: bool

When generating calls to Eigen in the CPU backend, use multi-threaded Eigen mode.

§xla_gpu_cuda_data_dir: String

Path to directory with cuda/ptx tools and libraries.

§xla_gpu_ftz: bool

Enable flush-to-zero semantics in the GPU backend.

§xla_llvm_enable_alias_scope_metadata: bool

If true, in LLVM-based backends, emit !alias.scope metadata in generated IR.

§xla_llvm_enable_noalias_metadata: bool

If true, in LLVM-based backends, emit !noalias metadata in the generated IR.

§xla_llvm_enable_invariant_load_metadata: bool

If true, in LLVM-based backends, emit !invariant.load metadata in the generated IR.

§xla_llvm_disable_expensive_passes: bool

If true, a set of expensive LLVM optimization passes will not be run.

§xla_test_all_output_layouts: bool

This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the computation will run n! times with all permutations of layouts for the output shape in rank n. For example, with a 3D shape, all permutations of the set {0, 1, 2} are tried.

§xla_test_all_input_layouts: bool

This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the computation will run for all permutations of layouts of all input arguments. For example, with 2 input arguments in 2D and 4D shapes, the computation will run 2! * 4! times.

§xla_hlo_graph_sharding_color: bool

Assign colors based on sharding information when generating the Graphviz HLO graph.

§xla_cpu_use_mkl_dnn: bool

Generate calls to MKL-DNN in the CPU backend.

§xla_gpu_enable_fast_min_max: bool

When true, we lower the Minimum and Maximum HLOs in the GPU backend such that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this flag is true we don’t propagate NaNs through Min and Max.

Note that this does not correspond to exactly the same behavior as the CPU flag above!

§xla_allow_excess_precision: bool

Allows XLA to increase the output precision of floating point operations and all floating-point conversions to be simplified, including those that affect the numerics. The FloatNormalization pass inserts many f32 -> bf16 -> f32 conversion pairs. These are not removed by the AlgebraicSimplifier, as that will only simplify conversions that are no-ops, e.g. bf16 -> f32 -> bf16. Removing these improves accuracy.

§xla_gpu_crash_on_verification_failures: bool

Crashes the program when any kind of verification fails, instead of just logging the failures. One example is cross checking of convolution results among different algorithms.

§xla_gpu_autotune_level: i32

  • 0: Disable gemm and convolution autotuning.
  • 1: Enable autotuning, but disable correctness checking.
  • 2: Also set output buffers to random numbers during autotuning.
  • 3: Also reset output buffers to random numbers after autotuning each algorithm.
  • 4+: Also check for correct outputs and for out-of-bounds reads/writes.

Default: 4.

§xla_force_host_platform_device_count: i32

Force the host platform to pretend that there are this many host “devices”. All these devices are backed by the same threadpool. Defaults to 1.

Setting this to anything other than 1 can increase overhead from context switching but we let the user override this behavior to help run tests on the host that run models in parallel across multiple devices.

§xla_gpu_disable_gpuasm_optimizations: bool

If set to true, XLA:GPU invokes ptxas with -O0 (the default is -O3).

§xla_gpu_shape_checks: i32

§xla_hlo_evaluator_use_fast_path: bool

Enable fast math with Eigen in the HLO evaluator.

§xla_allow_scalar_index_dynamic_ops: bool

Temporary option to allow support for both the R1 and the scalar index versions of DynamicSlice and DynamicUpdateSlice. Only used for testing.

§xla_step_marker_location: i32

Option to emit a target-specific marker to indicate the start of a training step. The location of the marker (if any) is determined by the option value.

§xla_dump_to: String

Directory to dump into.

§xla_dump_hlo_module_re: String

If specified, will only dump modules which match this regexp.

§xla_dump_hlo_pass_re: String

If this flag is specified, will also dump HLO before and after passes that match this regular expression. Set to .* to dump before/after all passes.

§xla_dump_hlo_as_text: bool

Specifies the format that HLO is dumped in. Multiple of these may be specified.

§xla_dump_hlo_as_proto: bool

§xla_dump_hlo_as_dot: bool

§xla_dump_hlo_as_url: bool

§xla_dump_hlo_as_html: bool

Dump HLO graphs as HTML (DOT -> SVG inlined in HTML).

§xla_dump_fusion_visualization: bool

Dump the visualization of the fusion progress.

§xla_dump_hlo_snapshots: bool

If true, every time an HLO module is run, we will dump an HloSnapshot (essentially, a serialized module plus its inputs) to the --xla_dump_to directory.

§xla_dump_include_timestamp: bool

Include a timestamp in the dumped filenames.

§xla_dump_max_hlo_modules: i32

Max number of HLO module dumps in a directory. Set to < 0 for unbounded.

§xla_dump_module_metadata: bool

Dump HloModuleMetadata as a text proto for each HLO module.

§xla_dump_compress_protos: bool

GZip-compress protos dumped via --xla_dump_hlo_as_proto.

§xla_dump_hlo_as_long_text: bool

Dump HLO in long text format. Ignored unless xla_dump_hlo_as_text is true.
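
A minimal sketch of a typical dump configuration combining the dump fields above; the directory path is illustrative:

// Hypothetical dump configuration; the path is illustrative.
let mut opts = DebugOptions::default();
opts.xla_dump_to = "/tmp/xla_dump".to_string();
opts.xla_dump_hlo_as_text = true; // dump HLO modules as text
opts.xla_dump_hlo_pass_re = ".*".to_string(); // also dump before/after every pass
opts.xla_dump_max_hlo_modules = 100; // bound the number of dumped modules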

§xla_gpu_force_conv_nchw: bool

Overrides for XLA GPU’s convolution layout heuristic.

§xla_gpu_force_conv_nhwc: bool

§xla_gpu_ptx_file: Vec<String>

Paths to files with ptx code.

§xla_gpu_dump_llvmir: bool

Whether to dump LLVM IR when compiling to PTX.

§xla_dump_enable_mlir_pretty_form: bool

Whether to dump MLIR using the pretty-print form.

§xla_gpu_algorithm_denylist_path: String

Denylist for cuDNN convolutions.

§xla_tpu_detect_nan: bool

Debug options that trigger execution errors when NaN or Inf are detected.

§xla_tpu_detect_inf: bool

§xla_cpu_enable_xprof_traceme: bool

True if TraceMe annotations are enabled for XLA:CPU.

§xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found: bool

It is usually preferable to not fallback to the driver; it can consume more memory, or have bugs.

§xla_gpu_asm_extra_flags: String

Extra parameters to pass to the GPU assembler.

§xla_multiheap_size_constraint_per_heap: i32

Per-heap size constraint. New heaps will be created if per-heap max size is reached.

§xla_detailed_logging: bool

Enable detailed logging into vlog. If this is disabled, no compilation summary will be printed at the end of the computation.

§xla_enable_dumping: bool

Enable HLO dumping. If this is disabled, no HLO modules will be dumped.

§xla_gpu_force_compilation_parallelism: i32

Overrides normal multi-threaded compilation setting to use this many threads. Setting to 0 (the default value) means no enforcement.

§xla_gpu_enable_llvm_module_compilation_parallelism: bool

§xla_gpu_deterministic_ops: bool

Guarantees run-to-run determinism. This flag implies --xla_gpu_exclude_nondeterministic_ops and in addition disables autotuning.

§xla_gpu_llvm_ir_file: Vec<String>

Paths to files with LLVM code.

§xla_gpu_disable_async_collectives: Vec<i32>

§xla_gpu_all_reduce_combine_threshold_bytes: i64

Size threshold (in bytes) for the GPU collective combiners.

§xla_gpu_all_gather_combine_threshold_bytes: i64

§xla_gpu_reduce_scatter_combine_threshold_bytes: i64

§xla_gpu_enable_all_gather_combine_by_dim: bool

Combine all-gather/reduce-scatter ops with the same dimension or irrespective of their dimension.

§xla_gpu_enable_reduce_scatter_combine_by_dim: bool

§xla_gpu_enable_reassociation_for_converted_ar: bool

Enable allreduce reassociation on allreduces that are converted to a wider type. The resulting allreduce will be promoted to a wider-typed allreduce.

§xla_gpu_all_reduce_blueconnect_num_devices_per_host: i32

Number of devices per host for the first stage of the BlueConnect decomposition pass. The pass will attempt to decompose all-reduce ops into a ReduceScatter-AllReduce-AllGather sequence, with the initial ReduceScatter being performed over all of the devices in the same host. Set to < 1 to disable all-reduce decomposition.

§xla_gpu_enable_while_loop_reduce_scatter_code_motion: bool

Enable hoisting of reduce-scatter out of while loops.

§xla_gpu_collective_inflation_factor: i32

Inflate collective cost by running each collective multiple times.

§xla_llvm_force_inline_before_split: bool

Whether to force inlining before the LLVM module split, to get more balanced splits for parallel compilation.

§xla_gpu_enable_cudnn_frontend: bool

Whether to use the cuDNN frontend API for convolutions when possible.

§xla_gpu_enable_cudnn_fmha: bool

§xla_gpu_fused_attention_use_cudnn_rng: bool

§xla_gpu_enable_cudnn_layer_norm: bool

Rewrite layer norm patterns into cuDNN library calls.

§xla_dump_disable_metadata: bool

Disable dumping metadata in HLO dumps.

§xla_dump_hlo_pipeline_re: String

If this flag is specified, will only dump HLO before and after passes in the pass pipeline that matches this regular expression. Default empty value enables dumping in all pipelines.

§xla_gpu_strict_conv_algorithm_picker: bool

If true, abort immediately when conv algorithm picker fails, rather than logging a warning and proceeding with fallback.

§xla_gpu_enable_custom_fusions: bool

If true, XLA will try to pattern match subgraphs of HLO operations into custom fusions registered in the current process (pre-compiled hand-written kernels, e.g. various GEMM fusions written in CUTLASS).

§xla_gpu_enable_custom_fusions_re: String

A regular expression enabling only a subset of custom fusions. Enabled only if xla_gpu_enable_custom_fusions is set to true.

§xla_gpu_enable_dynamic_slice_fusion: bool

Enables address computation fusion to optimize dynamic-slice and dynamic-update-slice operations around library calls.

§xla_gpu_nccl_termination_timeout_seconds: i64

Timeout in seconds before terminating jobs that are stuck in a NCCL Rendezvous. Negative value disables the timeout and will not terminate.

§xla_gpu_enable_shared_constants: bool

Enables shared constants for XLA/GPU. This allows large constants to be shared among multiple GPU executables.

§xla_gpu_enable_cublaslt: bool

Whether to use cuBLASLt for GEMMs on GPUs.

§xla_gpu_enable_command_buffer: Vec<i32>

Determine the types of commands that are recorded into command buffers.

§xla_gpu_graph_min_graph_size: i32

This number determines how many moved instructions like fusion kernels are required for a region to be captured as a function to be launched as a GPU graph.

§xla_gpu_graph_enable_concurrent_region: bool

Identify concurrent regions in GPU graphs and execute them concurrently.

§xla_gpu_redzone_scratch_max_megabytes: i64

Size threshold (in megabytes) for the GPU redzone scratch allocator.

§xla_gpu_redzone_padding_bytes: i64

Amount of padding the redzone allocator will put on one side of each buffer it allocates. (So the buffer’s total size will be increased by 2x this value.)

Higher values make it more likely that we’ll catch an out-of-bounds read or write. Smaller values consume less memory during autotuning. Note that a fused cudnn conv has up to 6 total buffers (4 inputs, 1 output, and 1 scratch), so this can be multiplied by quite a lot.

§xla_cpu_use_acl: bool

Generate calls to Arm Compute Library in the CPU backend.

§xla_cpu_strict_dot_conv_math: bool

By default, XLA:CPU will run fp16 dot/conv as fp32, as this is generally (much) faster on our hardware. Set this flag to disable this behavior.

§xla_gpu_use_runtime_fusion: bool

An option to enable using cuDNN runtime-compiled fusion kernels, which are available and recommended for Ampere+ GPUs.

§xla_dump_latency_hiding_schedule: bool

§xla_cpu_enable_mlir_tiling_and_fusion: bool

By default, MLIR lowering will use Linalg elementwise fusion. If this flag is enabled, the pipeline will use tiling, fusion, peeling, and vectorization instead.

§xla_cpu_enable_custom_matmul_tiling: bool

XLA:CPU-Next tiling parameters for matmul.

§xla_cpu_matmul_tiling_m_dim: i64

§xla_cpu_matmul_tiling_n_dim: i64

§xla_cpu_matmul_tiling_k_dim: i64

§xla_cpu_enable_mlir_fusion_outlining: bool

§xla_cpu_enable_experimental_deallocation: bool

If set, use the experimental deallocation pass from mlir-hlo.

§xla_gpu_enable_latency_hiding_scheduler: bool

§xla_gpu_enable_highest_priority_async_stream: bool

§xla_gpu_enable_analytical_latency_estimator: bool

§xla_gpu_lhs_enable_gpu_async_tracker: bool

§xla_gpu_pgle_profile_file_or_directory_path: String

§xla_gpu_memory_limit_slop_factor: i32

§xla_gpu_enable_pipelined_collectives: bool

§xla_gpu_enable_pipelined_all_reduce: bool

§xla_gpu_enable_pipelined_all_gather: bool

§xla_gpu_enable_pipelined_reduce_scatter: bool

§xla_gpu_enable_pipelined_p2p: bool

§xla_gpu_run_post_layout_collective_pipeliner: bool

§xla_gpu_collective_permute_decomposer_threshold: i64

The minimum data size in bytes to trigger collective-permute-decomposer transformation.

§xla_partitioning_algorithm: i32

The partitioning algorithm to be used in the PartitionAssignment pass.

§xla_gpu_enable_triton_gemm: bool

§xla_gpu_enable_cudnn_int8x32_convolution_reordering: bool

§xla_gpu_triton_gemm_any: bool

Creates Triton fusions for all supported GEMMs. To make sure only Triton GEMM is chosen by the autotuner, run with xla_gpu_cublas_fallback set to false.

§xla_gpu_exhaustive_tiling_search: bool

§xla_gpu_enable_priority_fusion: bool

§xla_gpu_dump_autotune_results_to: String

File to write autotune results to. It will be a binary file unless the name ends with .txt or .textproto. Warning: The results are written at every compilation, possibly multiple times per process. This only works on CUDA.

§xla_gpu_load_autotune_results_from: String

File to load autotune results from. It will be considered a binary file unless the name ends with .txt or .textproto. At most one loading will happen during the lifetime of one process, even if the first one is unsuccessful or different file paths are passed here. This only works on CUDA.
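
A minimal sketch of persisting autotune results across processes with the two fields above; the file path is illustrative:

// First run: write autotune results (binary unless the file name
// ends with .txt or .textproto).
let mut first_run = DebugOptions::default();
first_run.xla_gpu_dump_autotune_results_to = "/tmp/autotune.textproto".to_string();

// Later runs: load the saved results instead of re-autotuning.
let mut later_run = DebugOptions::default();
later_run.xla_gpu_load_autotune_results_from = "/tmp/autotune.textproto".to_string();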

§xla_gpu_target_config_filename: String

Description of the target platform in GpuTargetConfigProto format; if provided, deviceless compilation is assumed, and the current device is ignored.

§xla_gpu_auto_spmd_partitioning_memory_budget_gb: i32

Memory budget in GB per device for AutoSharding.

§xla_gpu_auto_spmd_partitioning_memory_budget_ratio: f32

See the definition of the xla_gpu_auto_spmd_partitioning_memory_budget_ratio flag for the meaning of this field.

§xla_gpu_triton_gemm_disable_reduced_precision_reduction: bool

§xla_gpu_triton_fusion_level: i32

§xla_gpu_dump_autotuned_gemm_fusions: bool

§xla_gpu_override_gemm_autotuner: String

§xla_gpu_copy_insertion_use_region_analysis: bool

§xla_gpu_collect_cost_model_stats: bool

If true, each fusion instruction will have a cost model runtime estimate in backend config after compilation.

§xla_gpu_enable_split_k_autotuning: bool

§xla_gpu_enable_reduction_epilogue_fusion: bool

Whether reduction epilogue fusion is enabled in fusion passes.

§xla_gpu_enable_nccl_clique_optimization: bool

Allow early return when acquiring NCCL cliques.

§xla_gpu_mock_custom_calls: bool

Replace custom calls with no-op operations.

§xla_gpu_cublas_fallback: bool

Allow Triton GEMM autotuning to fall back to cuBLAS when that is faster.

§xla_gpu_enable_while_loop_double_buffering: bool

Enable double buffering for loops.

§xla_gpu_enable_while_loop_unrolling: i32

Determine the while loop unrolling scheme.

§xla_gpu_ensure_minor_dot_contraction_dims: bool

Change the layout of the second triton dot operand to be column major. Only works for (bf16 x bf16) -> bf16.

§xla_gpu_filter_kernels_spilling_registers_on_autotuning: bool

Filter out kernels that spill registers during autotuning.

§xla_debug_buffer_assignment_show_max: i64

Maximum number of buffers to print when debugging buffer assignment.

§xla_gpu_llvm_verification_level: i32

§xla_gpu_enable_cub_radix_sort: bool

Enable radix sort using CUB.

§xla_gpu_threshold_for_windowed_einsum_mib: i64

Threshold to enable windowed einsum (collective matmul) in MiB.

§xla_gpu_enable_triton_hopper: bool

Enables currently disabled features within Triton for Hopper.

§xla_gpu_enable_nccl_user_buffers: bool

Enable NCCL user buffers.

§xla_gpu_enable_nccl_comm_splitting: bool

Enable NCCL communicator splitting.

§xla_gpu_enable_nccl_per_stream_comms: bool

Enable NCCL per stream communicators.

§xla_gpu_enable_libnvptxcompiler: bool

If enabled, uses the libnvptxcompiler library to compile PTX to cuBIN.

§xla_gpu_enable_dot_strength_reduction: bool

§xla_gpu_multi_streamed_windowed_einsum: bool

Whether to use multiple compute streams to run windowed einsum.

§xla_gpu_enable_bf16_6way_gemm: bool

If enabled, uses bf16_6way gemm to compute F32 gemm.

§xla_gpu_enable_bf16_3way_gemm: bool

If enabled, uses bf16_3way gemm to compute F32 gemm.

§xla_gpu_nccl_collective_max_nchannels: i64

Specify the maximum number of channels (SMs) NCCL will use for collective operations.

§xla_gpu_nccl_p2p_max_nchannels: i64

Specify the maximum number of channels (SMs) NCCL will use for p2p operations.

§xla_gpu_mlir_emitter_level: i64

Choose the level of MLIR emitters that are enabled. Current levels:

  • 0: Disabled.
  • 1: Loop emitter.
  • 2: + Loop-like emitters.
  • 3: + Transpose.
  • 4: + Reduce.

§xla_gpu_gemm_rewrite_size_threshold: i64

Threshold to rewrite matmul to cuBLAS or Triton (minimum combined number of elements of both matrices in non-batch dimensions to be considered for a rewrite).

§xla_gpu_require_complete_aot_autotune_results: bool

If true, will require complete AOT autotuning results; in the case of a missing AOT result, the model will not be compiled or executed and a NotFound error will be returned.

§xla_gpu_cudnn_gemm_fusion_level: i32

Let GEMM fusion autotuning probe cuDNN as a backend. Current levels:

  • 0: Disabled.
  • 1: Fusions of GEMM, elementwise, transpose/reshape operations.
  • 2: + Broadcasts, slicing.
  • 3: + Nontrivial noncontracting dimension reshapes/transposes.

§xla_gpu_use_memcpy_local_p2p: bool

This instructs the runtime whether to use memcpy for p2p communication when the source and target are located within a node (NVLink).

§xla_gpu_autotune_max_solutions: i64

If non-zero, limits the number of solutions to be used by the GEMM autotuner. This might be useful if the underlying math library returns too many GEMM solutions.

§xla_dump_large_constants: bool

If true, large constants will be printed out when dumping HLOs.

§xla_gpu_verify_triton_fusion_numerics: bool

If true, will verify that the numerical results of Triton fusions match the results of regular emitters.

§xla_gpu_dump_autotune_logs_to: String

File to write autotune logs to. It will be stored in text format.

§xla_reduce_window_rewrite_base_length: i64

Base length to rewrite the reduce window to; no rewrite if set to 0.

§xla_gpu_enable_host_memory_offloading: bool

If true, will enable host memory offloading on a device.

§xla_gpu_exclude_nondeterministic_ops: bool

Excludes non-deterministic ops from compiled executables. Unlike --xla_gpu_deterministic_ops, this does not disable autotuning - the compilation itself can be non-deterministic. At present, the HLO op SelectAndScatter does not have a deterministic XLA:GPU implementation; compilation errors out if SelectAndScatter is encountered. Scatter ops can be non-deterministic by default; these get converted to a deterministic implementation.

§xla_gpu_nccl_terminate_on_error: bool

If true, NCCL errors will terminate the process.

§xla_gpu_shard_autotuning: bool

§xla_gpu_enable_approx_costly_collectives: bool

§xla_gpu_kernel_cache_file: String

§xla_gpu_unsafe_pipelined_loop_annotator: bool

Recognises rotate-right patterns (slice, slice, concat) within a while loop and labels the while loop as a pipelined while loop. This is an unsafe flag.

§xla_gpu_per_fusion_autotune_cache_dir: String

§xla_cmd_buffer_trace_cache_size: i64

The command buffer trace cache size; increasing the cache size may sometimes reduce the chances of doing command buffer tracing for updating a command buffer instance.

§xla_gpu_temp_buffer_use_separate_color: bool

Enabling this flag will use a separate memory space color for the temp buffer, and a separate memory allocator to allocate it. Since there is then no interference from other memory allocations, the temp buffer is allocated at a fixed address on every iteration, which is good for CUDA graph performance.

§legacy_command_buffer_custom_call_targets: Vec<String>

Custom call targets with the legacy registry API (non-FFI API) that support recording to a command buffer custom command, i.e., custom call targets that support CUDA graph capturing for CUDA devices. This flag is read if the CUSTOM_CALL command type is recorded into a command buffer.

§xla_syntax_sugar_async_ops: bool

This flag controls HLO dumping and NVTX markers. If turned on, both HLO dumping and NVTX markers will use syntactic sugar wrappers as op names; if turned off, the actual op names will be shown.

Here is an example HLO excerpt with the flag off:

async_computation {
  param_0 = f32[1,4,8]{1,0,2} parameter(0)
  ROOT all-to-all.3.1 = f32[1,4,8]{1,0,2} all-to-all(param_0), replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}
}
…

all-to-all-start = ((f32[1,4,8]{1,0,2}), f32[1,4,8]{1,0,2}) async-start(bitcast.24.0), calls=async_computation, backend_config={…}
all-to-all-done = f32[1,4,8]{1,0,2} async-done(all-to-all-start)

and with the flag on:

all-to-all-start = ((f32[1,4,8]{1,0,2}), f32[1,4,8]{1,0,2}) all-to-all-start(bitcast.24.0), replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, backend_config={…}
all-to-all-done = f32[1,4,8]{1,0,2} all-to-all-done(all-to-all-start)

§xla_gpu_autotune_gemm_rtol: f32

Relative precision for comparing different GEMM solutions.

§xla_enable_command_buffers_during_profiling: bool

Allow launching command buffers while profiling is active. When disabled, execute in op-by-op mode. TODO(b/355487968): Remove this option when validation is complete.

§xla_gpu_cudnn_gemm_max_plans: i32

Limit for the number of kernel configurations (plans) to use during autotuning of cuDNN GEMM fusions. The more plans, the slower the autotuning, but potentially the higher the performance.

§xla_gpu_enable_libnvjitlink: bool

If enabled, uses the libnvjitlink library for PTX compilation and linking.

§xla_gpu_enable_triton_gemm_int4: bool

If enabled, generates Triton GEMM kernels for int4 inputs.

§xla_gpu_async_dot: bool

If true, XLA will wrap dot operations into async computations in an effort to parallelize matrix operations.

§xla_gpu_enable_pgle_accuracy_checker: bool

Enables strict PGLE checking. If an FDO profile is specified and the latency-hiding scheduler encounters missing instructions in the profile, compilation will halt.

§xla_gpu_executable_warn_stuck_timeout_seconds: i32

Timeouts for RendezvousSingle stuck warning and termination.

§xla_gpu_executable_terminate_timeout_seconds: i32

§xla_backend_extra_options: HashMap<String, String>

Extra options to pass to the compilation backend (e.g. LLVM); specific interpretation of these values is left to the backend.
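
A minimal sketch of passing a backend-specific option through this map; the key and value are illustrative, not real backend options:

// Hypothetical key/value; the interpretation is left to the backend.
let mut opts = DebugOptions::default();
opts.xla_backend_extra_options
    .insert("example-backend-flag".to_string(), "on".to_string());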

Implementations§

impl DebugOptions

pub fn xla_step_marker_location(&self) -> StepMarkerLocation

Returns the enum value of xla_step_marker_location, or the default if the field is set to an invalid enum value.

pub fn set_xla_step_marker_location(&mut self, value: StepMarkerLocation)

Sets xla_step_marker_location to the provided enum value.

pub fn xla_gpu_shape_checks(&self) -> ShapeChecks

Returns the enum value of xla_gpu_shape_checks, or the default if the field is set to an invalid enum value.

pub fn set_xla_gpu_shape_checks(&mut self, value: ShapeChecks)

Sets xla_gpu_shape_checks to the provided enum value.

pub fn xla_partitioning_algorithm(&self) -> PartitioningAlgorithm

Returns the enum value of xla_partitioning_algorithm, or the default if the field is set to an invalid enum value.

pub fn set_xla_partitioning_algorithm(&mut self, value: PartitioningAlgorithm)

Sets xla_partitioning_algorithm to the provided enum value.

pub fn xla_gpu_enable_command_buffer(&self) -> FilterMap<Cloned<Iter<'_, i32>>, fn(i32) -> Option<CommandBufferCmdType>>

Returns an iterator which yields the valid enum values contained in xla_gpu_enable_command_buffer.

pub fn push_xla_gpu_enable_command_buffer(&mut self, value: CommandBufferCmdType)

Appends the provided enum value to xla_gpu_enable_command_buffer.

pub fn xla_gpu_disable_async_collectives(&self) -> FilterMap<Cloned<Iter<'_, i32>>, fn(i32) -> Option<CollectiveOpType>>

Returns an iterator which yields the valid enum values contained in xla_gpu_disable_async_collectives.

pub fn push_xla_gpu_disable_async_collectives(&mut self, value: CollectiveOpType)

Appends the provided enum value to xla_gpu_disable_async_collectives.

pub fn xla_gpu_enable_while_loop_unrolling(&self) -> WhileLoopUnrolling

Returns the enum value of xla_gpu_enable_while_loop_unrolling, or the default if the field is set to an invalid enum value.

pub fn set_xla_gpu_enable_while_loop_unrolling(&mut self, value: WhileLoopUnrolling)

Sets xla_gpu_enable_while_loop_unrolling to the provided enum value.

pub fn xla_gpu_experimental_autotune_cache_mode(&self) -> AutotuneCacheMode

Returns the enum value of xla_gpu_experimental_autotune_cache_mode, or the default if the field is set to an invalid enum value.

pub fn set_xla_gpu_experimental_autotune_cache_mode(&mut self, value: AutotuneCacheMode)

Sets xla_gpu_experimental_autotune_cache_mode to the provided enum value.
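
A minimal sketch of the accessors above; the CommandBufferCmdType::Fusion variant name is an assumption about the generated enum:

// Hypothetical accessor usage; Fusion is an assumed variant name.
let mut opts = DebugOptions::default();
opts.push_xla_gpu_enable_command_buffer(CommandBufferCmdType::Fusion);
for cmd in opts.xla_gpu_enable_command_buffer() {
    // Only raw i32 values that map to a valid enum variant are yielded.
    println!("{cmd:?}");
}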

Trait Implementations§

impl Clone for DebugOptions

fn clone(&self) -> DebugOptions

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Debug for DebugOptions

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl Default for DebugOptions

fn default() -> Self

Returns the “default value” for a type.

impl Message for DebugOptions

fn encoded_len(&self) -> usize

Returns the encoded length of the message without a length delimiter.

fn clear(&mut self)

Clears the message, resetting all fields to their default.

fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> where Self: Sized

Encodes the message to a buffer.

fn encode_to_vec(&self) -> Vec<u8> where Self: Sized

Encodes the message to a newly allocated buffer.

fn encode_length_delimited(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> where Self: Sized

Encodes the message with a length-delimiter to a buffer.

fn encode_length_delimited_to_vec(&self) -> Vec<u8> where Self: Sized

Encodes the message with a length-delimiter to a newly allocated buffer.

fn decode(buf: impl Buf) -> Result<Self, DecodeError> where Self: Default

Decodes an instance of the message from a buffer.

fn decode_length_delimited(buf: impl Buf) -> Result<Self, DecodeError> where Self: Default

Decodes a length-delimited instance of the message from the buffer.

fn merge(&mut self, buf: impl Buf) -> Result<(), DecodeError> where Self: Sized

Decodes an instance of the message from a buffer, and merges it into self.

fn merge_length_delimited(&mut self, buf: impl Buf) -> Result<(), DecodeError> where Self: Sized

Decodes a length-delimited instance of the message from the buffer, and merges it into self.
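
A minimal round-trip sketch using the prost Message methods above (assuming the prost crate is in scope):

use prost::Message;

// Encode a modified DebugOptions to bytes, then decode it back.
let mut opts = DebugOptions::default();
opts.xla_hlo_profile = true;
let bytes = opts.encode_to_vec();
let decoded = DebugOptions::decode(bytes.as_slice()).expect("round-trip decode");
assert_eq!(opts, decoded); // PartialEq is implemented below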

impl PartialEq for DebugOptions

fn eq(&self, other: &DebugOptions) -> bool

Tests for self and other values to be equal, and is used by ==.

fn ne(&self, other: &Rhs) -> bool

Tests for !=. The default implementation is almost always sufficient, and should not be overridden without very good reason.

impl StructuralPartialEq for DebugOptions
