//! Crate root: module declarations and the flattened public re-export surface.
//!
//! This file contains no logic of its own. It (1) applies a handful of
//! crate-wide Clippy suppressions, (2) declares the crate's modules, and
//! (3) re-exports selected items from each module at the crate root via
//! `pub use`, using `as` renames wherever two modules would otherwise export
//! colliding names (e.g. `CacheStats as CompilationCacheStats`,
//! `NodeId as RewriteNodeId`, `TraceEntry as DebugTraceEntry`).

// Crate-wide Clippy suppressions.
// NOTE(review): these are blanket allows with no per-site justification;
// prefer fixing the offending sites or scoping each allow to the module
// that needs it — TODO audit which of these are still required.
#![allow(clippy::len_zero)]
#![allow(clippy::field_reassign_with_default)]
#![allow(clippy::manual_range_contains)]
#![allow(clippy::collapsible_if)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::or_fun_call)]
#![allow(clippy::derivable_impls)]
#![allow(clippy::manual_is_multiple_of)]
#![allow(clippy::overly_complex_bool_expr)]
#![allow(clippy::unwrap_or_default)]

// First module group (alphabetized), with its re-exports immediately below.
// NOTE(review): module declarations are split across two alphabetized runs
// (this one and `async_exec`..`workspace` further down); merging them into a
// single sorted list would simplify navigation.
pub mod causal;
pub mod cost_model;
pub mod critical_path;
pub mod execution_plan;
pub mod higher_order;
pub mod low_rank;
pub mod memo_cache;
pub mod partitioned;
pub mod step_executor;

// Re-exports for part of the group above; the remaining modules of that
// group (`causal`, `cost_model`, `critical_path`, `execution_plan`,
// `memo_cache`) are re-exported in the alphabetical run further down.
pub use higher_order::{
    FiniteDiffMethod, HessianComputer, HessianStats, JacobianComputer, JacobianConfig,
};
pub use low_rank::{
    LowRankApproximation, LowRankCandidate, LowRankConfig, LowRankError, LowRankInferencePass,
    LowRankPassStats, SvdResult, TruncatedSvd,
};
pub use partitioned::{
    AccumulationStrategy, PartitionConfig, PartitionedError, PartitionedReducer, PartitionedStats,
};
pub use step_executor::{BreakpointCondition, IntermediateValue, StepExecutor};

// Second module group (alphabetized). Private modules (`dummy_executor`,
// `dummy_tensor`, `error`, `ops`, `traits`) are surfaced only through the
// `pub use` re-exports near the end of this file.
pub mod async_exec;
pub mod auto_parallel;
pub mod autodiff;
pub mod backend_kind;
pub mod backend_tests;
pub mod batch;
pub mod beam_search;
pub mod cache;
pub mod cache_optimizer;
pub mod capabilities;
pub mod compilation;
pub mod constraint_propagation;
pub mod context;
pub mod debug;
pub mod diagnostics;
pub mod distributed;
mod dummy_executor;
mod dummy_tensor;
pub mod dynamic_batching;
pub mod eager;
mod error;
pub mod fusion;
pub mod gradcheck;
pub mod jit;
pub mod join_order;
pub mod learned_opt;
pub mod mcmc;
pub mod memory;
pub mod mixed_precision;
pub mod multimodel;
mod ops;
pub mod optimization;
pub mod parallel;
pub mod perfregression;
pub mod placement;
pub mod profiling;
pub mod profiling_optimizer;
pub mod pruning;
pub mod quantization;
pub mod recovery;
pub mod rewrite;
pub mod sampling;
pub mod scheduling;
pub mod shape;
pub mod simd;
pub mod sparse;
pub mod speculative;
pub mod strategy;
pub mod streaming;
pub mod symbolic_shape;
pub mod tensor_stats;
pub mod tensor_view;
pub mod trace_recording;
mod traits;
pub mod typesafe;
pub mod uncertainty;
pub mod validation;
pub mod visualization;
pub mod windowed_aggregation;
pub mod workspace;

// Test-only modules; compiled only under `cargo test`.
#[cfg(test)]
mod tests;
#[cfg(test)]
mod validation_tests;
#[cfg(test)]
mod memory_tests;

// Flattened public API, alphabetical by source module. `as` renames resolve
// cross-module name collisions so every item is reachable from the crate root.
//
// NOTE(review): `pub mod async_exec;` above is declared unconditionally, but
// this re-export is gated on the "async" feature — confirm whether the module
// declaration itself should also be behind `#[cfg(feature = "async")]`.
#[cfg(feature = "async")]
pub use async_exec::{
    AsyncConfig, AsyncExecutionError, AsyncExecutionHandle, AsyncExecutorPool, AsyncStats,
    AsyncStreamResults, BoxFuture, TlAsyncBatchExecutor, TlAsyncExecutor, TlAsyncStreamExecutor,
};
pub use auto_parallel::{
    AutoParallelError, AutoParallelizer, CostModel as AutoParallelCostModel, DependencyType,
    NodeId as AutoParallelNodeId, NodeInfo, ParallelExecutionPlan, ParallelStage,
    ParallelizationAnalysis, ParallelizationStrategy, WorkPartition,
};
pub use autodiff::{
    AccumulationConfig, ClippingStrategy, CustomGradientRegistry, GradientAccumulationStrategy,
    GradientAccumulator, GradientClipper, GradientConfig, GradientScaler, GradientScaling,
    GradientStats, TlEnhancedAutodiff,
};
pub use backend_kind::{BackendKind, BackendKindError};
pub use backend_tests::{
    assert_vec_close, print_test_summary, run_all_basic_tests, run_all_performance_tests,
    test_backend_edge_cases, test_backend_einsum, test_backend_elem_binary,
    test_backend_elem_unary, test_backend_forward, test_backend_large_tensors,
    test_backend_memory_efficiency, test_backend_reduce, test_backend_shapes, BackendTestAdapter,
    TestResult, DEFAULT_TOLERANCE,
};
pub use batch::{BatchResult, TlBatchExecutor};
pub use beam_search::{
    BeamHypothesis, BeamSearchConfig, BeamSearchDecoder, BeamSearchError, BeamSearchResult,
    BeamSearchStats, BeamState, BeamStepInput,
};
pub use cache::{CacheKey, CacheStats, EvictionPolicy, MemoryPool, PoolStats, TensorCache};
pub use cache_optimizer::{
    AccessPattern, CacheConfig, CacheLevel, CacheMetrics, CacheOptimizer, CacheOptimizerError,
    DataLayout, OptimizationStats as CacheOptimizationStats, TilingParams,
};
pub use capabilities::{BackendCapabilities, DType, DeviceType, Feature, TlCapabilities};
pub use causal::{
    ate_backdoor, ate_instrumental_variable, backdoor_criterion, do_intervention,
    find_backdoor_adjustment, frontdoor_criterion, propensity_score, BackdoorAdjustment,
    CausalError, CausalGraph, Intervention, ObservationalData, TreatmentEffect,
};
pub use compilation::{
    CacheStats as CompilationCacheStats, CompilationCache, CompilationConfig, CompilationKey,
    CompilationStats, CompiledGraph, GraphCompiler, OptimizationLevel, TlCompilableExecutor,
};
pub use constraint_propagation::{
    propagate_arc_consistency, solve, BinaryConstraint, ConstraintNetwork, ConstraintRelation,
    CspConfig, Domain, PropagationResult, SolveStats, VarOrdering,
};
pub use context::{ExecutionContext, ExecutionHook, ExecutionPhase, ExecutionState, LoggingHook};
pub use cost_model::{
    CostAwareSchedule, CostModel, CostModelConfig, FlopEstimate, GraphCostSummary,
    MemoryCostEstimate, NodeCostEstimate,
};
pub use critical_path::{
    critical_path, CriticalPathError, CriticalPathReport, CriticalPathResult, InferenceGraph,
    MissingCostWarning, NodeId as CriticalPathNodeId, NodeLatency,
};
pub use debug::{
    Breakpoint, BreakpointHit, BreakpointManager, ExecutionRecorder, ExecutionReport,
    ExecutionTrace, ExecutionTracer, OperationHandle, TensorInspector, TensorStats,
    TraceEntry as DebugTraceEntry, TraceSummary,
};
pub use diagnostics::{
    Diagnostic, DiagnosticCollector, MemoryDiagnostic, NodeExecutionDiagnostic,
    PerformanceDiagnostic, Severity, ShapeMismatchDiagnostic, SourceLocation,
    TypeMismatchDiagnostic,
};
pub use distributed::{
    CommunicationBackend, CommunicationOp, DataParallelCoordinator, DistributedConfig,
    DistributedExecutor, DistributedPlacementPlan, DistributedStats, DummyCommunicationBackend,
    ModelParallelCoordinator, ParallelismStrategy as DistributedParallelismStrategy,
    PipelineParallelCoordinator, ReductionOp, ShardingSpec, TlDistributedExecutor,
};
pub use dummy_executor::DummyExecutor;
pub use dummy_tensor::DummyTensor;
pub use dynamic_batching::{
    AdaptiveBatcher, BatchRequest, BatchingError, BatchingStats, DynamicBatchConfig,
    DynamicBatcher, Priority, RequestMetadata, RequestQueue,
};
pub use eager::{EagerOp, EagerOps, EagerTape, TlEagerAutodiff, Variable, VariableGrad};
pub use error::ExecutorError;
pub use execution_plan::{
    compute_memory_timeline, ExecutionPlan, MemoryTimelineEntry, PlanFormatter, PlanStep,
};
pub use fusion::{
    FusionCandidate, FusionConfig, FusionCostModel, FusionError, FusionOptimizer, FusionPattern,
    FusionStats, FusionStrategy,
};
pub use gradcheck::{
    compare_gradients, numerical_gradient_central, numerical_gradient_forward, quick_check,
    GradCheckConfig, GradCheckResult, GradientChecker, GradientError,
};
pub use jit::{
    AdaptiveOptimizationPlan, AdaptiveOptimizer, HotPathDetector, JitCache, JitCacheEntry,
    JitCacheStats, JitCompiler, JitConfig, JitEntryStats, JitKey, JitStats, SpecializationContext,
    TlJitExecutor,
};
pub use join_order::{
    JoinCondition, JoinOptimizerConfig, JoinOrderError, JoinOrderOptimizer, JoinPlan, JoinPlanNode,
    JoinStats, Relation as JoinRelation,
};
pub use learned_opt::{
    CostPrediction, FeatureVector, FusionRecommendation, LearnedOptError, LearnedOptimizer,
    LearningStats, LearningStrategy, ModelType, NodeId as LearnedOptNodeId, OptimizationAction,
    RewardSignal, ScheduleRecommendation, TrainingExample,
};
pub use mcmc::{
    autocorrelation, compute_diagnostics, effective_sample_size, gelman_rubin, ChainDiagnostics,
    GaussianProposal, HamiltonianMonteCarlo, IndependentGaussianProposal, LogProb, LogProbFn,
    McmcConfig, McmcError, McmcResult, McmcRng, MetropolisHastings, Proposal,
};
pub use memo_cache::{
    ExprMemoCache, MemoCacheBuilder, MemoConfig, MemoEvictionPolicy, MemoKey, MemoLookupResult,
    MemoStats,
};
pub use memory::{MemoryEstimate, MemoryEstimator, TensorMemory};
pub use mixed_precision::{
    GradientCheckpoint, LossScaler, LossScalerStats, LossScalingStrategy, MixedPrecisionConfig,
    MixedPrecisionError, MixedPrecisionState, MixedPrecisionStats, PrecisionMode,
};
pub use multimodel::{
    CascadeConfig, CoordinationStats, EnsembleConfig, EnsembleStrategy, ModelMetadata,
    MultiModelCoordinator, MultiModelError, ResourceRequirements, RoutingStrategy,
    TlEnsembleExecutor, TlModelRouter,
};
pub use ops::{ElemOp, ReduceOp};
pub use optimization::{
    FusionOpportunity, FusionPlanner, FusionType, GraphOptimizer, OptimizationResult,
};
pub use parallel::{
    LoadBalanceStats, NumaNode, NumaStrategy, ParallelConfig, ParallelError, SchedulerStats,
    StealStrategy, Task, TaskId, TaskPriority, WorkStealingScheduler,
};
pub use perfregression::{
    BenchmarkBaseline, BenchmarkComparison, BenchmarkConfig, BenchmarkStats, PerfRegression,
    RegressionReport,
};
pub use placement::{Device, PlacementOptimizer, PlacementPlan, PlacementStrategy};
pub use profiling::{
    Bottleneck, BottleneckAnalyzer, BottleneckReport, PerformanceBaseline, PerformanceComparison,
    ProfileData, ProfileStatistics, Profiler, ProfilerHook, TimelineProfiler, TlProfiledExecutor,
    TraceEntry,
};
pub use profiling_optimizer::{
    ExecutionProfile, Hotspot, OptimizationGoal, OptimizationReport, OptimizationStrategy,
    ProfilingOptimizer, ProfilingOptimizerError, TuningConfig,
};
pub use pruning::{
    compute_sparsity, row_norms, MagnitudePruner, PruningConfig, PruningError, SparsityPattern,
    SparsityStats,
};
pub use quantization::{
    CalibrationStats, CalibrationStrategy, FakeQuantize, QuantizationConfig, QuantizationError,
    QuantizationGranularity, QuantizationMode, QuantizationParams, QuantizationSummary,
    QuantizationSymmetry, QuantizationType, Quantizer,
};
pub use recovery::{
    Checkpoint, CheckpointManager, DegradationPolicy, FailureInfo, FallbackStrategy,
    RecoveryConfig, RecoveryMetadata, RecoveryResult, RecoveryStats, RecoveryStrategy, RetryPolicy,
    TlRecoverableExecutor,
};
pub use rewrite::{
    CommonRules, Match, NodeId as RewriteNodeId, Pattern, ReplacementFn, RewriteEngine,
    RewriteError, RewriteRule, RewriteStats, RewriteStrategy,
};
pub use sampling::{
    entropy, log_softmax, perplexity, softmax, ConfigurableSampler, GreedyDecoder, SampledToken,
    SamplingConfig, SamplingError, TemperatureSampler, TopKSampler, TopPSampler,
};
pub use scheduling::{ExecutionSchedule, NodeCost, Scheduler, SchedulingStrategy};
pub use shape::{DimSize, ShapeInferenceContext, TensorShape};
pub use simd::{
    AlignedBuffer, CpuArchitecture, SimdCapabilities, SimdError, SimdInstructionSet,
    SimdOptimizationHints,
};
pub use sparse::{
    detect_sparsity, to_sparse_if_beneficial, SparseCOO, SparseCSC, SparseCSR, SparseError,
    SparseFormat, SparseTensor, SparseTensorBuilder,
};
pub use speculative::{
    BranchOutcome, NodeId as SpeculativeNodeId, PredictionStrategy, RollbackPolicy,
    SpeculationStats, SpeculativeError, SpeculativeExecutor, SpeculativeTask,
};
pub use strategy::{
    ExecutionMode, ExecutionStrategy, GradientStrategy, MemoryStrategy, ParallelismStrategy,
    StrategyOptimizer,
};
pub use streaming::{
    BackpressureConfig, BackpressureStrategy, ChunkIterator, ChunkMetadata, StreamProcessor,
    StreamResult, StreamingConfig, StreamingConfigV2, StreamingMode, StreamingStats,
    TlStreamingExecutor, WatermarkConfig,
};
pub use symbolic_shape::{
    propagate_chain, propagate_einsum_shapes, ShapeError, SymbolicDim, SymbolicShape,
    SymbolicShapeConstraint, SymbolicShapeEnv,
};
pub use tensor_stats::{
    ActivationStatistics, AnomalyDetector, AnomalyKind, AnomalyReport, StatsError,
    TensorStats as TensorStatsSummary,
};
pub use tensor_view::{
    InPlaceMode, InPlaceOps, SliceSpec, TensorView, TensorViewable, ViewBuilder,
};
pub use trace_recording::{
    CommunicationBottleneck, DeviceSummary, LoadBalanceMetrics, OpSummary, RecordedExecutionTrace,
    RecordedTraceEntry, TraceAnalyzer, TraceRecorder,
};
pub use traits::{TlAutodiff, TlExecutor};
pub use typesafe::{
    BroadcastShape, Dim, DimMul, DimOp, DimSize as TypesafeDimSize, Dyn, EinsumSpec, FixedShape,
    Matrix, MatrixOps, Nat, Scalar, ShapeConstraint, ShapedTensor, Static, Tensor3D, Tensor4D,
    TensorBuilder, TypedBatch, TypedInputs, TypedOutputs, TypedTensor, TypedTensorOps, Vector, D1,
    D2, D3, D4, D5, D6, S, Z,
};
pub use uncertainty::{
    find_optimal_temperature, temperature_scale, CalibrationBin, CalibrationMetrics,
    ConfidenceInterval, IntervalMethod, MonteCarloEstimator, PredictionInterval, UncertaintyError,
    UncertaintyEstimate,
};
pub use validation::{GraphValidator, ValidationResult};
pub use visualization::{
    ExportFormat, GraphConfig, GraphVisualizer, TensorStatsVisualizer, TimelineConfig,
    TimelineVisualizer, VisualizationFormat,
};
pub use windowed_aggregation::{
    WindowAggregation, WindowConfig, WindowError, WindowResult, WindowType, WindowedAggregation,
};
pub use workspace::{
    AllocationStrategy, DefragmentationResult, SharedWorkspacePool, Workspace, WorkspaceConfig,
    WorkspaceError, WorkspacePool, WorkspaceStats,
};