from typing import Any, Optional
# Module-level declarations: each name is annotated only, no value is assigned here.
# Package version, declared as a string.
__version__: str
# Boolean flags.
# NOTE(review): presumably each reports whether the correspondingly named
# optional component is available at runtime -- confirm against the
# implementation module.
HAS_BURN: bool
HAS_CUBECL: bool
HAS_RUST_AI_CORE: bool
HAS_TRITTER_ACCEL: bool
class ModelConfig:
    """Declared interface of a model configuration object.

    Only signatures are given here; every body is a ``...`` placeholder.
    """

    # Declared attributes: two strings and a string-keyed mapping.
    name: str
    version: str
    params: dict[str, Any]

    def __init__(self, name: str, version: str) -> None:
        """Accept the ``name`` and ``version`` strings."""
        ...

    def with_param(self, key: str, value: Any) -> ModelConfig:
        """Take a ``key``/``value`` pair; annotated as returning a ``ModelConfig``.

        NOTE(review): whether the receiver is mutated or a new object is
        returned is not visible here -- confirm against the implementation.
        """
        ...

    def param(self, key: str) -> Any | None:
        """Take a ``key``; annotated as returning a value or ``None``."""
        ...
class ModelConfigBuilder:
    """Declared interface of a builder whose ``build`` is annotated as returning ``ModelConfig``.

    ``name``, ``version`` and ``param`` are each annotated as returning a
    ``ModelConfigBuilder``, so calls can be chained.
    """

    def __init__(self) -> None:
        """Take no arguments."""
        ...

    def name(self, name: str) -> ModelConfigBuilder:
        """Take the ``name`` string."""
        ...

    def version(self, version: str) -> ModelConfigBuilder:
        """Take the ``version`` string."""
        ...

    def param(self, key: str, value: Any) -> ModelConfigBuilder:
        """Take a ``key`` string and an arbitrary ``value``."""
        ...

    def build(self) -> ModelConfig:
        """Annotated as returning a ``ModelConfig``.

        NOTE(review): behaviour when name/version were never supplied is not
        visible here -- confirm against the implementation.
        """
        ...
class NodeId:
    """Declared interface of a node identifier wrapping an integer ``value``.

    The class declares both ``__hash__`` and ``__eq__``, plus ``__int__``
    for conversion via ``int(...)``.
    """

    # Declared attribute: the integer carried by the identifier.
    value: int

    def __init__(self, value: int) -> None:
        """Accept the integer ``value``."""
        ...

    def __hash__(self) -> int:
        """Annotated as returning an ``int`` hash."""
        ...

    def __eq__(self, other: object) -> bool:
        """Compare with ``other``.

        ``other`` is typed as ``object`` to match ``object.__eq__``: narrowing
        it to ``NodeId`` is an incompatible override that type checkers
        reject, and ``==`` may be invoked with any operand.
        """
        ...

    def __int__(self) -> int:
        """Annotated as returning an ``int``."""
        ...
class GraphNode:
    """Declared interface of a graph node.

    A node is declared with an ``id``, a ``name``, a ``config`` and a
    string-keyed ``metadata`` mapping.
    """

    # Declared attributes.
    id: NodeId
    name: str
    config: ModelConfig
    metadata: dict[str, Any]

    def __init__(self, id: NodeId, name: str, config: ModelConfig) -> None:
        """Accept the node's ``id``, ``name`` and ``config``."""
        ...

    def with_metadata(self, key: str, value: Any) -> GraphNode:
        """Take a ``key``/``value`` pair; annotated as returning a ``GraphNode``.

        NOTE(review): mutate-in-place versus copy semantics are not visible
        here -- confirm against the implementation.
        """
        ...
class BuildGraph:
    """Declared interface of a graph made of ``GraphNode`` items and id-pair edges."""

    # Declared attributes: the node list and the edges as (NodeId, NodeId) tuples.
    nodes: list[GraphNode]
    edges: list[tuple[NodeId, NodeId]]

    def __init__(self) -> None:
        """Take no arguments."""
        ...

    def add_node(self, name: str, config: ModelConfig) -> NodeId:
        """Take a ``name`` and ``config``; annotated as returning a ``NodeId``."""
        ...

    def add_edge(self, from_id: NodeId, to_id: NodeId) -> None:
        """Take the two endpoint ids ``from_id`` and ``to_id``.

        NOTE(review): handling of unknown ids is not visible here -- confirm
        against the implementation.
        """
        ...

    def node_count(self) -> int:
        """Annotated as returning an ``int``."""
        ...

    def edge_count(self) -> int:
        """Annotated as returning an ``int``."""
        ...

    def stable_hash(self) -> str:
        """Annotated as returning a ``str``.

        NOTE(review): the hash algorithm and what it covers are not visible
        here -- confirm against the implementation.
        """
        ...

    def __len__(self) -> int:
        """Support ``len(graph)``.

        NOTE(review): presumably the node count -- confirm.
        """
        ...

    def __bool__(self) -> bool:
        """Support truth-testing of the graph."""
        ...
class NullBackend:
    """Declared interface of a backend described by a ``name`` and a ``device`` string."""

    # Declared attributes.
    name: str
    device: str

    def __init__(self, device: str = "cpu") -> None:
        """Accept the ``device`` string, which defaults to ``"cpu"``."""
        ...

    @staticmethod
    def cpu() -> NullBackend:
        """Static factory taking no arguments; annotated as returning a ``NullBackend``."""
        ...
class TraceLevel:
    """Declared interface of the trace level type.

    Four class-level members are declared, each typed as ``TraceLevel``.
    """

    # Declared members, in source order.
    Debug: TraceLevel
    Info: TraceLevel
    Warn: TraceLevel
    Error: TraceLevel
class TraceEvent:
    """Declared interface of a single trace event."""

    # Declared attributes. ``span_id`` and ``trace_id`` may be ``None``;
    # ``timestamp_secs`` is declared here but is not a constructor argument.
    id: str
    message: str
    level: TraceLevel
    span_id: str | None
    trace_id: str | None
    timestamp_secs: float

    def __init__(
        self,
        id: str,
        message: str,
        level: TraceLevel | None = None,
        span_id: str | None = None,
        trace_id: str | None = None,
    ) -> None:
        """Accept ``id`` and ``message``; the remaining arguments default to ``None``.

        NOTE(review): the ``level`` attribute is declared non-optional, so a
        ``None`` argument is presumably replaced by some default level --
        confirm against the implementation.
        """
        ...
class InMemoryTraceSink:
    """Declared interface of a sink that accepts ``TraceEvent`` objects and lists them back."""

    def __init__(self) -> None:
        """Take no arguments."""
        ...

    def record(self, event: TraceEvent) -> None:
        """Take one ``event``."""
        ...

    def events(self) -> list[TraceEvent]:
        """Annotated as returning a list of ``TraceEvent``.

        NOTE(review): ordering and copy-versus-view semantics are not
        visible here -- confirm against the implementation.
        """
        ...

    def __len__(self) -> int:
        """Support ``len(sink)``.

        NOTE(review): presumably the number of recorded events -- confirm.
        """
        ...
class BuildContext:
    """Declared interface of a build context constructed from a backend and a trace sink."""

    def __init__(self, backend: NullBackend, trace: InMemoryTraceSink) -> None:
        """Accept the ``backend`` and the ``trace`` sink."""
        ...

    @staticmethod
    def with_null_backend() -> BuildContext:
        """Static factory taking no arguments; annotated as returning a ``BuildContext``."""
        ...
class BuildPipeline:
    """Declared interface of a build pipeline.

    Offers three static factories, a chainable ``with_stage``, and a
    synchronous and an ``async`` execution entry point.
    """

    def __init__(self) -> None:
        """Take no arguments."""
        ...

    @staticmethod
    def standard() -> BuildPipeline:
        """Static factory; annotated as returning a ``BuildPipeline``."""
        ...

    @staticmethod
    def for_training() -> BuildPipeline:
        """Static factory; annotated as returning a ``BuildPipeline``."""
        ...

    @staticmethod
    def for_inference() -> BuildPipeline:
        """Static factory; annotated as returning a ``BuildPipeline``."""
        ...

    def with_stage(self, name: str) -> BuildPipeline:
        """Take a stage ``name``; annotated as returning a ``BuildPipeline``.

        NOTE(review): which stage names are accepted is not visible here --
        confirm against the implementation.
        """
        ...

    def execute(self, ctx: BuildContext, graph: BuildGraph) -> BuildGraph:
        """Take a context and a graph; annotated as returning a ``BuildGraph``."""
        ...

    async def execute_async(self, ctx: BuildContext, graph: BuildGraph) -> BuildGraph:
        """Coroutine with the same parameters and annotated result as ``execute``."""
        ...

    def __len__(self) -> int:
        """Support ``len(pipeline)``.

        NOTE(review): presumably the number of stages -- confirm.
        """
        ...
class ValidationError:
    """Declared interface of a validation finding: a ``field`` name and a ``message``.

    As declared here the class has no explicit base class.
    """

    # Declared attributes.
    field: str
    message: str

    def __init__(self, field: str, message: str) -> None:
        """Accept the ``field`` and ``message`` strings."""
        ...
class NameValidator:
    """Declared interface of a validator that inspects a ``ModelConfig``."""

    def __init__(self) -> None:
        """Take no arguments."""
        ...

    def validate(self, config: ModelConfig) -> list[ValidationError]:
        """Take a ``config``; annotated as returning a list of ``ValidationError``.

        NOTE(review): presumably checks the config's name and returns an
        empty list when it is acceptable -- confirm.
        """
        ...
class VersionValidator:
    """Declared interface of a validator that inspects a ``ModelConfig``."""

    def __init__(self) -> None:
        """Take no arguments."""
        ...

    def validate(self, config: ModelConfig) -> list[ValidationError]:
        """Take a ``config``; annotated as returning a list of ``ValidationError``.

        NOTE(review): presumably checks the config's version and returns an
        empty list when it is acceptable -- confirm.
        """
        ...
class CompositeValidator:
    """Declared interface of a validator assembled through chainable ``with_*`` calls."""

    def __init__(self) -> None:
        """Take no arguments."""
        ...

    def with_name_validator(self) -> CompositeValidator:
        """Take no arguments; annotated as returning a ``CompositeValidator``."""
        ...

    def with_version_validator(self) -> CompositeValidator:
        """Take no arguments; annotated as returning a ``CompositeValidator``."""
        ...

    def validate(self, config: ModelConfig) -> list[ValidationError]:
        """Take a ``config``; annotated as returning a list of ``ValidationError``.

        NOTE(review): presumably gathers the findings of the validators
        added so far -- confirm against the implementation.
        """
        ...